From 948d72ce32b0a1c87a146cc67e3918d4e6f07f28 Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Mon, 9 Dec 2024 17:19:38 +0200 Subject: [PATCH] init model_wrapper --- nncf/common/factory.py | 20 +++++- nncf/common/model.py | 67 +++++++++++++++++++ nncf/onnx/quantization/quantize_model.py | 12 ++-- .../openvino/quantization/quantize_ifmodel.py | 5 +- nncf/openvino/quantization/quantize_model.py | 10 +-- nncf/quantization/algorithms/algorithm.py | 9 +-- .../algorithms/bias_correction/algorithm.py | 11 +-- .../algorithms/channel_alignment/algorithm.py | 6 +- .../fast_bias_correction/algorithm.py | 17 +++-- .../algorithms/min_max/algorithm.py | 23 ++++--- nncf/quantization/algorithms/pipeline.py | 61 ++++++----------- .../algorithms/post_training/algorithm.py | 3 +- .../algorithms/smooth_quant/algorithm.py | 15 +++-- nncf/torch/quantization/quantize_model.py | 11 +-- .../test_templates/test_bias_correction.py | 12 ++-- .../test_templates/test_channel_alignment.py | 3 +- .../test_fast_bias_correction.py | 7 +- .../test_templates/test_ptq_params.py | 11 ++- .../test_templates/test_smooth_quant.py | 8 +-- tests/onnx/quantization/common.py | 6 +- .../test_fq_params_calculation.py | 7 +- .../native/quantization/test_graphs.py | 7 +- .../ptq/test_calculation_quantizer_params.py | 9 ++- tests/torch/ptq/test_fq_params_calculation.py | 5 +- tests/torch/ptq/test_graphs.py | 6 +- 25 files changed, 220 insertions(+), 131 deletions(-) create mode 100644 nncf/common/model.py diff --git a/nncf/common/factory.py b/nncf/common/factory.py index c5a921c8068..6f17dfc7fbd 100644 --- a/nncf/common/factory.py +++ b/nncf/common/factory.py @@ -9,7 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import TypeVar +import os +from typing import Any, Dict, Optional, Tuple, TypeVar import nncf from nncf.common.engine import Engine @@ -26,13 +27,20 @@ class NNCFGraphFactory: @staticmethod - def create(model: TModel) -> NNCFGraph: + def create( + model: TModel, input_args: Optional[Tuple[Any, ...]] = None, input_kwargs: Optional[Dict[str, Any]] = None + ) -> NNCFGraph: """ Factory method to create backend-specific NNCFGraph instance based on the input model. :param model: backend-specific model instance :return: backend-specific NNCFGraph instance """ + if input_args is None: + input_args = () + if input_kwargs is None: + input_kwargs = {} + model_backend = get_backend(model) if model_backend == BackendType.ONNX: from nncf.onnx.graph.nncf_graph_builder import GraphConverter @@ -47,7 +55,13 @@ def create(model: TModel) -> NNCFGraph: return GraphConverter.create_nncf_graph(model) if model_backend == BackendType.TORCH: - return model.nncf.get_graph() + if os.getenv("NNCF_EXPERIMENTAL_TORCH_TRACING") is None: + return model.nncf.get_graph() + else: + from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph + + return build_nncf_graph(model, *input_args, **input_kwargs) + raise nncf.UnsupportedBackendError( "Cannot create backend-specific graph because {} is not supported!".format(model_backend.value) ) diff --git a/nncf/common/model.py b/nncf/common/model.py new file mode 100644 index 00000000000..af29c876c11 --- /dev/null +++ b/nncf/common/model.py @@ -0,0 +1,67 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Dict, Optional, TypeVar + +from nncf.common.factory import NNCFGraphFactory +from nncf.common.graph.graph import NNCFGraph + +TModel = TypeVar("TModel") + + +class StateAttributes: + """ + The state attributes. + """ + + EXAMPLE_INPUT_ARGS = "example_input_args" + EXAMPLE_INPUT_KWARGS = "example_input_kwargs" + + +class ModelWrapper: + """ + A wrapper class for the original model. + + :param _model: The original model to be wrapped. + :param _graph: The graph representation of the model. + :param state: The storage of the model state. + """ + + def __init__( + self, model: TModel, graph: Optional[NNCFGraph] = None, state: Optional[Dict[str, Any]] = None + ) -> None: + self._model = model + self._graph = graph + self.state = state if state is not None else {} + + @property + def model(self) -> TModel: + """ + Retrieves the original model. + """ + return self._model + + @property + def graph(self) -> NNCFGraph: + """ + Returns the NNCFGraph representation of the model. + + If the graph has not been created yet, it will be created using the model, + example input arguments, and example input keyword arguments stored in the state. 
+ """ + if self._graph is None: + self._graph = NNCFGraphFactory.create( + model=self.model, + input_args=self.state.get(StateAttributes.EXAMPLE_INPUT_ARGS), + input_kwargs=self.state.get(StateAttributes.EXAMPLE_INPUT_KWARGS), + ) + return self._graph diff --git a/nncf/onnx/quantization/quantize_model.py b/nncf/onnx/quantization/quantize_model.py index 7a4665d1a0c..05f709939ed 100644 --- a/nncf/onnx/quantization/quantize_model.py +++ b/nncf/onnx/quantization/quantize_model.py @@ -15,10 +15,10 @@ import nncf from nncf.common.logging.logger import nncf_logger +from nncf.common.model import ModelWrapper from nncf.common.quantization.structs import QuantizationPreset from nncf.data import Dataset from nncf.onnx.graph.metatypes.groups import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS -from nncf.onnx.graph.nncf_graph_builder import GraphConverter from nncf.parameters import DropType from nncf.parameters import ModelType from nncf.parameters import QuantizationMode @@ -78,11 +78,13 @@ def quantize_impl( advanced_parameters=advanced_parameters, ) - graph = GraphConverter.create_nncf_graph(model) - warning_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS) - quantized_model = quantization_algorithm.apply(model, graph, dataset=calibration_dataset) + model_wrapper = ModelWrapper(model) + warning_model_no_batchwise_support( + model_wrapper.graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS + ) + quantized_model = quantization_algorithm.apply(model_wrapper, dataset=calibration_dataset) - return quantized_model + return quantized_model.model def quantize_with_accuracy_control_impl( diff --git a/nncf/openvino/quantization/quantize_ifmodel.py b/nncf/openvino/quantization/quantize_ifmodel.py index 07d22171a17..3fa652bef21 100644 --- a/nncf/openvino/quantization/quantize_ifmodel.py +++ b/nncf/openvino/quantization/quantize_ifmodel.py @@ -25,6 +25,7 @@ from nncf.common.graph.transformations.layout import 
TransformationLayout from nncf.common.logging import nncf_logger from nncf.common.logging.track_progress import track +from nncf.common.model import ModelWrapper from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.openvino.graph.metatypes.openvino_metatypes import OVIfMetatype from nncf.openvino.graph.model_utils import remove_friendly_name_duplicates @@ -155,7 +156,9 @@ def apply_algorithm_if_bodies( """ nncf_logger.info(f"Iteration [{current_model_num}/{len(graphs)}] ...") parent_graph = graphs[graph_id] - quantized_model = algorithm.apply(parent_model, parent_graph, parent_statistic_points, parent_dataset) + quantized_model = algorithm.apply( + ModelWrapper(parent_model, parent_graph), parent_statistic_points, parent_dataset + ).model if get_number_if_op(parent_model) == 0: return quantized_model, current_model_num model_transformer_fp32 = factory.ModelTransformerFactory.create(parent_model) diff --git a/nncf/openvino/quantization/quantize_model.py b/nncf/openvino/quantization/quantize_model.py index 46db1c50cca..24b96e1e6ae 100644 --- a/nncf/openvino/quantization/quantize_model.py +++ b/nncf/openvino/quantization/quantize_model.py @@ -19,13 +19,13 @@ from nncf.common.factory import NNCFGraphFactory from nncf.common.factory import StatisticsAggregatorFactory from nncf.common.logging import nncf_logger +from nncf.common.model import ModelWrapper from nncf.common.quantization.structs import QuantizationPreset from nncf.data import Dataset from nncf.openvino.graph.metatypes.groups import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS from nncf.openvino.graph.metatypes.openvino_metatypes import OVIfMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import get_node_metatype from nncf.openvino.graph.model_utils import remove_friendly_name_duplicates -from nncf.openvino.graph.nncf_graph_builder import GraphConverter from nncf.openvino.graph.node_utils import get_number_if_op from nncf.openvino.quantization.backend_parameters 
import BackendParameters from nncf.openvino.quantization.backend_parameters import is_weight_compression_needed @@ -166,9 +166,11 @@ def native_quantize_impl( ignored_scope=ignored_scope, advanced_parameters=advanced_parameters, ) - graph = GraphConverter.create_nncf_graph(model) - warning_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS) - quantized_model = quantization_algorithm.apply(model, graph, dataset=calibration_dataset) + model_wrapper = ModelWrapper(model) + warning_model_no_batchwise_support( + model_wrapper.graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS + ) + quantized_model = quantization_algorithm.apply(model_wrapper, dataset=calibration_dataset).model if is_weight_compression_needed(advanced_parameters): compress_quantize_weights_transformation(quantized_model) diff --git a/nncf/quantization/algorithms/algorithm.py b/nncf/quantization/algorithms/algorithm.py index befe0a82f9d..f5fe6896971 100644 --- a/nncf/quantization/algorithms/algorithm.py +++ b/nncf/quantization/algorithms/algorithm.py @@ -14,7 +14,7 @@ from typing import List, Optional, TypeVar from nncf import Dataset -from nncf.common.graph.graph import NNCFGraph +from nncf.common.model import ModelWrapper from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType @@ -38,8 +38,7 @@ def available_backends(self) -> List[BackendType]: @abstractmethod def apply( self, - model: TModel, - graph: NNCFGraph, + model: ModelWrapper, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, ) -> TModel: @@ -47,18 +46,16 @@ def apply( Applies the algorithm to the model. :param model: Model for applying algorithm. - :param graph: Model graph. :param statistic_points: Statistic points with collected statistics values. :param dataset: A representative dataset for the calibration process. :return: A resulting model. 
""" @abstractmethod - def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: + def get_statistic_points(self, model: ModelWrapper) -> StatisticPointsContainer: """ Returns statistic points, for which StatisticsCollector should collect statistics. :param model: Model for statistics collection. - :param graph: Model graph. :return: Statistic points, for which StatisticsCollector should collect statistics. """ diff --git a/nncf/quantization/algorithms/bias_correction/algorithm.py b/nncf/quantization/algorithms/bias_correction/algorithm.py index 63db2ee0adf..fdfba42ace5 100644 --- a/nncf/quantization/algorithms/bias_correction/algorithm.py +++ b/nncf/quantization/algorithms/bias_correction/algorithm.py @@ -25,6 +25,7 @@ from nncf.common.graph.transformations.commands import TransformationCommand from nncf.common.graph.transformations.layout import TransformationLayout from nncf.common.logging.track_progress import track +from nncf.common.model import ModelWrapper from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType @@ -133,11 +134,11 @@ def _set_backend_entity(self, model: TModel) -> None: def apply( self, - model: TModel, - graph: NNCFGraph, + model_wrapper: ModelWrapper, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, ) -> TModel: + model = model_wrapper.model self._set_backend_entity(model) main_transformations_layout = TransformationLayout() main_model_transformer = ModelTransformerFactory.create(model) @@ -553,8 +554,10 @@ def output_filter_func(point): output_fp.extend(tensor_collector.get_statistics().mean_values) return output_fp - def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: - self._set_backend_entity(model) + def get_statistic_points(self, model_wrapper: ModelWrapper) -> 
StatisticPointsContainer: + model = model_wrapper.model + graph = model_wrapper.graph + self._set_backend_entity(model_wrapper.model) statistic_container = StatisticPointsContainer() nodes_with_bias = [ diff --git a/nncf/quantization/algorithms/channel_alignment/algorithm.py b/nncf/quantization/algorithms/channel_alignment/algorithm.py index b30749b6d2c..2a0fb0f4a1b 100644 --- a/nncf/quantization/algorithms/channel_alignment/algorithm.py +++ b/nncf/quantization/algorithms/channel_alignment/algorithm.py @@ -25,6 +25,7 @@ from nncf.common.graph.utils import get_reduction_axes from nncf.common.logging import nncf_logger from nncf.common.logging.track_progress import track +from nncf.common.model import ModelWrapper from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType @@ -381,8 +382,9 @@ def _get_target_point_and_node_in(self, conv_in, add_in) -> Tuple[TargetPoint, N node_in, ) - def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: - self._set_backend_entity(model) + def get_statistic_points(self, model_wrapper: ModelWrapper) -> StatisticPointsContainer: + self._set_backend_entity(model_wrapper.model) + graph = model_wrapper.graph statistic_container = StatisticPointsContainer() for conv_in, add_in, _ in self._get_node_pairs(graph): diff --git a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py index 3d104cad3c9..40c8a87a364 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py +++ b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py @@ -16,13 +16,13 @@ from nncf import Dataset from nncf.common.factory import EngineFactory from nncf.common.factory import ModelTransformerFactory -from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.model_transformer 
import ModelTransformer from nncf.common.graph.transformations.commands import TargetPoint from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout from nncf.common.logging import nncf_logger from nncf.common.logging.track_progress import track +from nncf.common.model import ModelWrapper from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType @@ -129,11 +129,12 @@ def _set_backend_entity(self, model: TModel) -> None: def apply( self, - model: TModel, - graph: NNCFGraph, + model_wrapper: ModelWrapper, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, - ) -> TModel: + ) -> ModelWrapper: + model = model_wrapper.model + graph = model_wrapper.graph self._set_backend_entity(model) model_transformer = ModelTransformerFactory.create(model) @@ -207,7 +208,9 @@ def apply( transformation_layout.register(self._backend_entity.create_bias_correction_command(node, bias_value, graph)) transformed_model = model_transformer.transform(transformation_layout) - return transformed_model + return ModelWrapper( + model=transformed_model, graph=graph, state=model_wrapper.state # BC does not change the model's graph + ) @staticmethod def _get_bias_shift_magnitude(current_bias_value: Tensor, updated_bias_value: Tensor) -> Tensor: @@ -345,7 +348,9 @@ def _get_bias_shift( bias_shift = fns.stack(output_fp) - q_outputs return bias_shift - def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: + def get_statistic_points(self, model_wrapper: ModelWrapper) -> StatisticPointsContainer: + model = model_wrapper.model + graph = model_wrapper.graph self._set_backend_entity(model) nodes_with_bias = [ node for node in graph.get_all_nodes() if self._backend_entity.is_node_with_bias(node, graph) 
diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index dea9211b734..b802ea4536e 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -31,6 +31,7 @@ from nncf.common.hardware.config import get_hw_config_type from nncf.common.insertion_point_graph import InsertionPointGraph from nncf.common.logging import nncf_logger +from nncf.common.model import ModelWrapper from nncf.common.quantization.config_assignment import assign_qconfig_lists_to_modules from nncf.common.quantization.initialization.range import RangeInitCollectorParams from nncf.common.quantization.quantizer_propagation.solver import QuantizerPropagationRule @@ -889,14 +890,16 @@ def _get_quantization_points_overflow_fix( def apply( self, - model: TModel, - graph: NNCFGraph, + model_wrapper: ModelWrapper, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, - ) -> TModel: + ) -> ModelWrapper: transformation_layout = TransformationLayout() - model_transformer = ModelTransformerFactory.create(model) - quantization_target_points, unified_scale_groups = self._get_quantization_target_points(model, graph) + model_transformer = ModelTransformerFactory.create(model_wrapper.model) + graph = model_wrapper.graph + quantization_target_points, unified_scale_groups = self._get_quantization_target_points( + model_wrapper.model, graph + ) quantization_points_overflow_fix = self._get_quantization_points_overflow_fix( self._overflow_fix, quantization_target_points, graph ) @@ -987,12 +990,12 @@ def filter_func(point: StatisticPoint) -> bool: if not transformation_layout.transformations: nncf_logger.info("The model has no operations to apply quantization.") quantized_model = model_transformer.transform(transformation_layout) - return quantized_model + return ModelWrapper(quantized_model, state=model_wrapper.state) - def get_statistic_points(self, model: 
TModel, graph: NNCFGraph) -> StatisticPointsContainer: - self._set_backend_entity(model) + def get_statistic_points(self, model_wrapper: ModelWrapper) -> StatisticPointsContainer: + self._set_backend_entity(model_wrapper.model) self._reset_cache() - quantization_target_points, _ = self._get_quantization_target_points(model, graph) + quantization_target_points, _ = self._get_quantization_target_points(model_wrapper.model, model_wrapper.graph) output = StatisticPointsContainer() for quantization_target_point, qconfig in quantization_target_points.items(): nncf_logger.debug( @@ -1000,7 +1003,7 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin f" with type {quantization_target_point.type} for statistics collection" ) stat_collector = self._get_stat_collector( - graph, quantization_target_point, qconfig, self._batchwise_statistics + model_wrapper.graph, quantization_target_point, qconfig, self._batchwise_statistics ) output.add_statistic_point( StatisticPoint( diff --git a/nncf/quantization/algorithms/pipeline.py b/nncf/quantization/algorithms/pipeline.py index cd615258553..ae9d4276a3c 100644 --- a/nncf/quantization/algorithms/pipeline.py +++ b/nncf/quantization/algorithms/pipeline.py @@ -11,10 +11,9 @@ from typing import Dict, List, Optional, TypeVar, Union -from nncf.common.factory import NNCFGraphFactory from nncf.common.factory import StatisticsAggregatorFactory -from nncf.common.graph.graph import NNCFGraph from nncf.common.logging import nncf_logger +from nncf.common.model import ModelWrapper from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import get_backend @@ -27,8 +26,7 @@ def collect_statistics( containers: Union[StatisticPointsContainer, List[StatisticPointsContainer]], - model: TModel, - graph: NNCFGraph, + model_state: ModelWrapper, dataset: Dataset, ) -> StatisticPointsContainer: """ @@ -36,17 +34,17 @@ def 
collect_statistics( :param statistic_points: Statistic points that need to be collected. :param model: A model. - :param graph: A graph assosiated with a model. + :param graph: A graph associated with a model. :param dataset: A dataset. :return: Collected statistics. """ if not isinstance(containers, list): containers = [containers] - statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset) + statistics_aggregator = StatisticsAggregatorFactory.create(model_state.model, dataset) for container in containers: statistics_aggregator.register_statistic_points(container) - statistics_aggregator.collect_statistics(model, graph) + statistics_aggregator.collect_statistics(model_state.model, model_state.graph) return statistics_aggregator.statistic_points @@ -96,8 +94,7 @@ def run_step( self, step_index: int, step_statistics: StatisticPointsContainer, - model: TModel, - graph: NNCFGraph, + model_wrapper: ModelWrapper, ) -> TModel: """ Executes a provided pipeline step on the provided model. @@ -105,36 +102,31 @@ def run_step( :param step_index: Zero-based index of the pipeline step that should be executed :param step_statistics: Statistics required to execute a pipeline step. :param model: A model to which a pipeline step will be applied. - :param graph: A graph assosiated with a model. + :param graph: A graph associated with a model. :return: The updated model after executing the pipeline step. 
""" - current_model = model - current_graph = graph + current_model = model_wrapper - pipeline_steps = self._remove_unsupported_algorithms(get_backend(model)) + pipeline_steps = self._remove_unsupported_algorithms(get_backend(model_wrapper.model)) pipeline_step = pipeline_steps[step_index] - for algorithm in pipeline_step[:-1]: - current_model = algorithm.apply(current_model, current_graph, step_statistics) - current_graph = NNCFGraphFactory.create(current_model) - current_model = pipeline_step[-1].apply(current_model, current_graph, step_statistics) - + for algorithm in pipeline_step: + current_model = algorithm.apply(current_model, step_statistics) return current_model def run_from_step( self, - model: TModel, + model: ModelWrapper, dataset: Dataset, - graph: Optional[NNCFGraph] = None, start_step_index: int = 0, step_index_to_statistics: Optional[Dict[int, StatisticPointsContainer]] = None, - ) -> TModel: + ) -> ModelWrapper: """ Executes the pipeline from the specified pipeline step to the end. :param model: This is the model after the (start_step_index - 1)-th pipeline step, or the initial model if start_step_index is 0. :param dataset: A dataset that holds the data items for pipeline steps. - :param graph: A graph assosiated with a model. + :param graph: A graph associated with a model. :param start_step_index: Zero-based pipeline step index from which the pipeline should be executed. :param step_index_to_statistics: A mapping from pipeline step index to statistics @@ -142,47 +134,38 @@ def run_from_step( :return: The updated model after executing the pipeline from the specified pipeline step to the end. 
""" - pipeline_steps = self._remove_unsupported_algorithms(get_backend(model)) + pipeline_steps = self._remove_unsupported_algorithms(get_backend(model.model)) if step_index_to_statistics is None: step_index_to_statistics = {} # The `step_model` and `step_graph` entities are required to execute `step_index`-th pipeline step step_model = model - step_graph = graph for step_index in range(start_step_index, len(pipeline_steps)): - # Create graph required to run current pipeline step - if step_graph is None: - step_graph = NNCFGraphFactory.create(step_model) - # Collect statistics required to run current pipeline step step_statistics = step_index_to_statistics.get(step_index) if step_statistics is None: - statistic_points = self.get_statistic_points_for_step(step_index, step_model, step_graph) - step_statistics = collect_statistics(statistic_points, step_model, step_graph, dataset) + statistic_points = self.get_statistic_points_for_step(step_index, step_model) + step_statistics = collect_statistics(statistic_points, step_model, dataset) # Run current pipeline step - step_model = self.run_step(step_index, step_statistics, step_model, step_graph) - - step_graph = None # We should rebuild the graph for the next pipeline step + step_model = self.run_step(step_index, step_statistics, step_model) return step_model - def get_statistic_points_for_step( - self, step_index: int, model: TModel, graph: NNCFGraph - ) -> StatisticPointsContainer: + def get_statistic_points_for_step(self, step_index: int, model_wrapper: ModelWrapper) -> StatisticPointsContainer: """ Returns statistics that should be collected to execute `step_index`-th pipeline step. :param step_index: Zero-based index of the pipeline step. :param model: A model. - :param graph: A graph assosiated with a model. + :param graph: A graph associated with a model. :return: Statistics that should be collected to execute `step_index`-th pipeline step. 
""" container = StatisticPointsContainer() - pipeline_steps = self._remove_unsupported_algorithms(get_backend(model)) + pipeline_steps = self._remove_unsupported_algorithms(get_backend(model_wrapper.model)) pipeline_step = pipeline_steps[step_index] for algorithm in pipeline_step: - for statistic_points in algorithm.get_statistic_points(model, graph).values(): + for statistic_points in algorithm.get_statistic_points(model_wrapper).values(): for statistic_point in statistic_points: container.add_statistic_point(statistic_point) diff --git a/nncf/quantization/algorithms/post_training/algorithm.py b/nncf/quantization/algorithms/post_training/algorithm.py index 862dc5d5037..9fd02014770 100644 --- a/nncf/quantization/algorithms/post_training/algorithm.py +++ b/nncf/quantization/algorithms/post_training/algorithm.py @@ -95,7 +95,6 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin def apply( self, model: TModel, - graph: NNCFGraph, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, ) -> TModel: @@ -109,4 +108,4 @@ def apply( if statistic_points: step_index_to_statistics = {0: statistic_points} - return self._pipeline.run_from_step(model, dataset, graph, 0, step_index_to_statistics) + return self._pipeline.run_from_step(model, dataset, 0, step_index_to_statistics) diff --git a/nncf/quantization/algorithms/smooth_quant/algorithm.py b/nncf/quantization/algorithms/smooth_quant/algorithm.py index 83aefc6709a..77ab1d8ed9b 100644 --- a/nncf/quantization/algorithms/smooth_quant/algorithm.py +++ b/nncf/quantization/algorithms/smooth_quant/algorithm.py @@ -24,6 +24,7 @@ from nncf.common.graph.utils import get_reduction_axes from nncf.common.logging import nncf_logger from nncf.common.logging.track_progress import track +from nncf.common.model import ModelWrapper from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import 
StatisticPointsContainer from nncf.common.utils.backend import BackendType @@ -98,11 +99,12 @@ def _set_backend_entity(self, model: TModel) -> None: def apply( self, - model: TModel, - graph: NNCFGraph, + model_wrapper: ModelWrapper, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, - ) -> TModel: + ) -> ModelWrapper: + model = model_wrapper.model + graph = model_wrapper.graph self._set_backend_entity(model) alpha_map = self._get_alpha_map() @@ -176,7 +178,7 @@ def apply( transformation_layout.register(scale_insertion_command) transformed_model = model_transformer.transform(transformation_layout) - return transformed_model + return ModelWrapper(model=transformed_model, state=model_wrapper.state) @staticmethod def _calculate_scale_and_ratio( @@ -245,7 +247,10 @@ def _get_statistics_for_node( statistics_for_node.append(statistic) return statistics_for_node - def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: + def get_statistic_points(self, model_wrapper: ModelWrapper) -> StatisticPointsContainer: + model = model_wrapper.model + graph = model_wrapper.graph + statistic_container = StatisticPointsContainer() self._set_backend_entity(model) diff --git a/nncf/torch/quantization/quantize_model.py b/nncf/torch/quantization/quantize_model.py index 23cb451f5fe..fb3ee792730 100644 --- a/nncf/torch/quantization/quantize_model.py +++ b/nncf/torch/quantization/quantize_model.py @@ -16,6 +16,7 @@ import nncf from nncf.common.factory import NNCFGraphFactory +from nncf.common.model import ModelWrapper from nncf.common.quantization.structs import QuantizationPreset from nncf.data import Dataset from nncf.parameters import BackupMode @@ -32,6 +33,7 @@ from nncf.scopes import IgnoredScope from nncf.torch.graph.operator_metatypes import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS from nncf.torch.model_creation import wrap_model +from nncf.torch.nncf_network import NNCFNetwork DEFAULT_RANGE_TYPE = 
"mean_min_max" @@ -72,12 +74,13 @@ def quantize_impl( ignored_scope=ignored_scope, advanced_parameters=advanced_parameters, ) - graph = nncf_network.nncf.get_graph() - warning_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS) - quantized_model = quantization_algorithm.apply(nncf_network, graph, dataset=calibration_dataset) + model_wrapper = ModelWrapper(nncf_network) + warning_model_no_batchwise_support( + model_wrapper.graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS + ) + quantized_model: NNCFNetwork = quantization_algorithm.apply(model_wrapper, dataset=calibration_dataset).model quantized_model.nncf.disable_dynamic_graph_building() - return quantized_model diff --git a/tests/cross_fw/test_templates/test_bias_correction.py b/tests/cross_fw/test_templates/test_bias_correction.py index 81b638eb900..6fe670c7496 100644 --- a/tests/cross_fw/test_templates/test_bias_correction.py +++ b/tests/cross_fw/test_templates/test_bias_correction.py @@ -15,6 +15,7 @@ import pytest from nncf.common.factory import NNCFGraphFactory +from nncf.common.model import ModelWrapper from nncf.data import Dataset from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.advanced_parameters import OverflowFix @@ -121,9 +122,8 @@ def quantized_test_model(self, tmpdir) -> TModel: dataset = Dataset(self.get_dataset(model_cls.INPUT_SIZE), self.get_transform_fn()) quantization_algorithm = self.get_quantization_algorithm(disable_bias_correction=True) - graph = NNCFGraphFactory.create(model) - quantized_model = quantization_algorithm.apply(model, graph, dataset=dataset) - modified_model = self.remove_fq_from_inputs(quantized_model) + quantized_model = quantization_algorithm.apply(ModelWrapper(model), dataset=dataset) + modified_model = self.remove_fq_from_inputs(quantized_model.model) return modified_model @pytest.mark.parametrize( @@ -150,8 +150,7 @@ def 
test_update_bias(self, model_cls, ref_biases, tmpdir): dataset = Dataset(self.get_dataset(model_cls.INPUT_SIZE), self.get_transform_fn()) quantization_algorithm = self.get_quantization_algorithm() - graph = NNCFGraphFactory.create(model) - quantized_model = quantization_algorithm.apply(model, graph, dataset=dataset) + quantized_model = quantization_algorithm.apply(ModelWrapper(model), dataset=dataset) mapped_ref_biases = self.map_references(ref_biases, model_cls) self.check_bias(quantized_model, mapped_ref_biases) @@ -171,10 +170,9 @@ def test__get_subgraph_data_for_node(self, quantized_test_model, layer_name, ref def test_verify_collected_stat_inputs_map(self, model_cls, ref_stat_inputs_map, tmpdir): model = self.backend_specific_model(model_cls(), tmpdir) - graph = NNCFGraphFactory.create(model) bc_algo = self.get_bias_correction_algorithm() - bc_algo.get_statistic_points(model, graph) + bc_algo.get_statistic_points(ModelWrapper(model)) collected_stat_inputs_map = getattr(bc_algo, "_collected_stat_inputs_map") assert collected_stat_inputs_map == ref_stat_inputs_map diff --git a/tests/cross_fw/test_templates/test_channel_alignment.py b/tests/cross_fw/test_templates/test_channel_alignment.py index 7995f91961c..373fe356802 100644 --- a/tests/cross_fw/test_templates/test_channel_alignment.py +++ b/tests/cross_fw/test_templates/test_channel_alignment.py @@ -22,6 +22,7 @@ from nncf.common.graph.model_transformer import ModelTransformer from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.commands import TransformationType +from nncf.common.model import ModelWrapper from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic @@ -511,7 +512,7 @@ class MockBackend(backend_cls): MockBackend.get_statistic_collector = 
mocker.MagicMock(return_value=ref_stat_collector) algorithm._backend_entity = MockBackend - statistic_container = algorithm.get_statistic_points(None, nncf_graph) + statistic_container = algorithm.get_statistic_points(ModelWrapper(None, nncf_graph)) backend_cls = self.get_backend_cls() target_node_name = "/Add_1_0" if num_biases else "/Conv_1_0" diff --git a/tests/cross_fw/test_templates/test_fast_bias_correction.py b/tests/cross_fw/test_templates/test_fast_bias_correction.py index 899be7d9a1a..22c91654c11 100644 --- a/tests/cross_fw/test_templates/test_fast_bias_correction.py +++ b/tests/cross_fw/test_templates/test_fast_bias_correction.py @@ -14,7 +14,7 @@ import pytest -from nncf.common.factory import NNCFGraphFactory +from nncf.common.model import ModelWrapper from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.advanced_parameters import OverflowFix from nncf.quantization.algorithms.fast_bias_correction.algorithm import FastBiasCorrection @@ -115,7 +115,6 @@ def test_update_bias(self, model_cls, ref_bias, tmpdir): dataset = get_static_dataset(model_cls.INPUT_SIZE, self.get_transform_fn(), self.fn_to_type) quantization_algorithm = self.get_quantization_algorithm() - graph = NNCFGraphFactory.create(model) - quantized_model = quantization_algorithm.apply(model, graph, dataset=dataset) + quantized_model = quantization_algorithm.apply(ModelWrapper(model), dataset=dataset) - self.check_bias(quantized_model, ref_bias) + self.check_bias(quantized_model.model, ref_bias) diff --git a/tests/cross_fw/test_templates/test_ptq_params.py b/tests/cross_fw/test_templates/test_ptq_params.py index eacf57652e7..d8989a36d94 100644 --- a/tests/cross_fw/test_templates/test_ptq_params.py +++ b/tests/cross_fw/test_templates/test_ptq_params.py @@ -21,6 +21,7 @@ from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.operator_metatypes import OutputNoopMetatype from 
nncf.common.graph.transformations.commands import TargetType +from nncf.common.model import ModelWrapper from nncf.common.quantization.structs import QuantizationPreset from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode from nncf.common.quantization.structs import QuantizerConfig @@ -203,7 +204,7 @@ def test_range_estimator_per_tensor(self, test_params, range_estimator_params): assert min_max_algo._range_estimator_params[QuantizerGroup.ACTIVATIONS] == range_estimator_params params = test_params["test_range_estimator_per_tensor"] - stat_points = min_max_algo.get_statistic_points(params["model"], params["nncf_graph"]) + stat_points = min_max_algo.get_statistic_points(ModelWrapper(params["model"], params["nncf_graph"])) assert len(stat_points) == params["stat_points_num"] for _, stat_point in stat_points.items(): @@ -374,7 +375,7 @@ def test_unified_scales_command_creation(self, mocker): Tensor(self.get_backend_tensor(idx - 1)), Tensor(self.get_backend_tensor(idx + 2)) ) stats.add_statistic_point(StatisticPoint(tp, tc, algo._algorithm_key)) - algo.apply(model, model.nncf_graph, stats) + algo.apply(ModelWrapper(model, model.nncf_graph), stats) mock_transformer.transform.assert_called_once() layout = mock_transformer.transform.call_args.args[0] self.check_unified_scale_layout(layout, unified_scales_group) @@ -423,7 +424,5 @@ def test_empty_statistics(self, mode, mocker): "nncf.quantization.algorithms.min_max.algorithm.MinMaxQuantization._get_quantization_points_overflow_fix", return_value=mocker.MagicMock(), ) - with pytest.raises(nncf.InternalError) as exc_info: - algo.apply(None, None, stat_points) - - assert str(exc_info.value) == "Statistics were not collected for the node A" + with pytest.raises(nncf.InternalError, match="Statistics were not collected for the node A"): + algo.apply(mocker.MagicMock(), stat_points) diff --git a/tests/cross_fw/test_templates/test_smooth_quant.py b/tests/cross_fw/test_templates/test_smooth_quant.py 
index f4ea260c14e..69ab9096758 100644 --- a/tests/cross_fw/test_templates/test_smooth_quant.py +++ b/tests/cross_fw/test_templates/test_smooth_quant.py @@ -19,6 +19,7 @@ from nncf.common.factory import NNCFGraphFactory from nncf.common.factory import StatisticsAggregatorFactory from nncf.common.graph.graph import NNCFNode +from nncf.common.model import ModelWrapper from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator from nncf.parameters import ModelType @@ -165,8 +166,7 @@ def test_smooth_quant_algo(self, model_cls, reference_values, tmpdir): dataset = get_static_dataset(model_cls.INPUT_SIZE, self.get_transform_fn(), self.fn_to_type) quantization_algorithm = self.get_quantization_algorithm(self.get_ignored_scope(model_cls)) - graph = NNCFGraphFactory.create(model) - quantized_model = quantization_algorithm.apply(model, graph, dataset=dataset) + quantized_model = quantization_algorithm.apply(ModelWrapper(model), dataset=dataset).model self.check_scales(quantized_model, reference_values, model_cls) @@ -246,7 +246,7 @@ def test_empty_stats(self, mocker, tmpdir): graph = NNCFGraphFactory.create(model) algo = SmoothQuant(subset_size=1, inplace_statistics=False) - algo_statistic_points = algo.get_statistic_points(model, graph) + algo_statistic_points = algo.get_statistic_points(ModelWrapper(model)) statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset) statistics_aggregator.register_statistic_points(algo_statistic_points) statistics_aggregator.collect_statistics(model, graph) @@ -260,7 +260,7 @@ def test_empty_stats(self, mocker, tmpdir): mocked_transformer = mocker.MagicMock() mocker.patch("nncf.common.factory.ModelTransformerFactory.create", return_value=mocked_transformer) - algo.apply(model, graph, algo_statistic_points) + algo.apply(ModelWrapper(model), algo_statistic_points) mocked_transformer.transform.assert_called_once() arg = 
mocked_transformer.transform.call_args.args[0] diff --git a/tests/onnx/quantization/common.py b/tests/onnx/quantization/common.py index 18f36b29ee4..48bf787f2d1 100644 --- a/tests/onnx/quantization/common.py +++ b/tests/onnx/quantization/common.py @@ -16,6 +16,7 @@ import onnx from nncf import Dataset +from nncf.common.model import ModelWrapper from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic from nncf.onnx.graph.nncf_graph_builder import GraphConverter from nncf.onnx.graph.onnx_helper import get_edge_dtype @@ -108,7 +109,6 @@ def min_max_quantize_model( ) -> onnx.ModelProto: if convert_model_opset: original_model = convert_opset_version(original_model) - graph = GraphConverter.create_nncf_graph(original_model) dataset = get_random_dataset_for_test(original_model, dataset_has_batch_size) quantization_params = {} if quantization_params is None else quantization_params @@ -123,8 +123,8 @@ def min_max_quantize_model( post_training_quantization = PostTrainingQuantization(subset_size=1, **quantization_params) - quantized_model = post_training_quantization.apply(original_model, graph, dataset=dataset) - return quantized_model + quantized_model = post_training_quantization.apply(ModelWrapper(original_model), dataset=dataset) + return quantized_model.model def ptq_quantize_model( diff --git a/tests/openvino/native/quantization/test_fq_params_calculation.py b/tests/openvino/native/quantization/test_fq_params_calculation.py index 5751a34f39b..e1f0d4f8793 100644 --- a/tests/openvino/native/quantization/test_fq_params_calculation.py +++ b/tests/openvino/native/quantization/test_fq_params_calculation.py @@ -15,6 +15,7 @@ import pytest import torch +from nncf.common.model import ModelWrapper from nncf.common.quantization.structs import QuantizationPreset from nncf.openvino.graph.nncf_graph_builder import GraphConverter from nncf.openvino.statistics.aggregator import OVStatisticsAggregator @@ -70,11 +71,11 @@ def quantize_model(ov_model, 
q_params): min_max_algo = MinMaxQuantization(subset_size=1, **q_params) statistics_aggregator = OVStatisticsAggregator(dataset) - statistic_points = min_max_algo.get_statistic_points(ov_model, graph) + statistic_points = min_max_algo.get_statistic_points(ModelWrapper(ov_model, graph)) statistics_aggregator.register_statistic_points(statistic_points) statistics_aggregator.collect_statistics(ov_model, graph) - quantized_model = min_max_algo.apply(ov_model, graph, statistics_aggregator.statistic_points) - return quantized_model + quantized_model = min_max_algo.apply(ModelWrapper(ov_model, graph), statistics_aggregator.statistic_points) + return quantized_model.model @pytest.fixture(params=[True, False], ids=["inplace", "out_of_place"], name="inplace_statistics") diff --git a/tests/openvino/native/quantization/test_graphs.py b/tests/openvino/native/quantization/test_graphs.py index 7dc3c94c081..2352a009f27 100644 --- a/tests/openvino/native/quantization/test_graphs.py +++ b/tests/openvino/native/quantization/test_graphs.py @@ -18,6 +18,7 @@ import pytest from nncf import Dataset +from nncf.common.model import ModelWrapper from nncf.common.quantization.structs import QuantizationPreset from nncf.openvino.graph.nncf_graph_builder import GraphConverter from nncf.openvino.quantization.quantize_model import quantize_impl @@ -137,10 +138,12 @@ def smooth_quant_model(ov_model: ov.Model, q_params: Dict, quantize=True): smooth_quant_algo = SmoothQuant(subset_size=1) statistics_aggregator = OVStatisticsAggregator(dataset) - statistic_points = smooth_quant_algo.get_statistic_points(ov_model, graph) + statistic_points = smooth_quant_algo.get_statistic_points(ModelWrapper(ov_model, graph)) statistics_aggregator.register_statistic_points(statistic_points) statistics_aggregator.collect_statistics(ov_model, graph) - modified_model = smooth_quant_algo.apply(ov_model, graph, statistics_aggregator.statistic_points) + modified_model = smooth_quant_algo.apply( + ModelWrapper(ov_model, 
graph), statistics_aggregator.statistic_points + ).model if quantize: modified_model = quantize_model(modified_model, q_params) diff --git a/tests/torch/ptq/test_calculation_quantizer_params.py b/tests/torch/ptq/test_calculation_quantizer_params.py index 556b5f9e387..00f82f0e538 100644 --- a/tests/torch/ptq/test_calculation_quantizer_params.py +++ b/tests/torch/ptq/test_calculation_quantizer_params.py @@ -20,6 +20,7 @@ from nncf import Dataset from nncf.common.graph.transformations.commands import TargetType +from nncf.common.model import ModelWrapper from nncf.common.quantization.structs import QuantizationPreset from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode from nncf.common.quantization.structs import QuantizerConfig @@ -314,16 +315,14 @@ def test_quantizer_parameters_export(tmp_path: Path, _seed): statistics_aggregator = PTStatisticsAggregator(dataset) nncf_network = wrap_model(model, torch.ones([1, 3, 32, 32]), True) - statistic_points = min_max_algo.get_statistic_points(nncf_network, nncf_network.nncf.get_graph()) + statistic_points = min_max_algo.get_statistic_points(ModelWrapper(nncf_network)) statistics_aggregator.register_statistic_points(statistic_points) statistics_aggregator.collect_statistics(model, nncf_network.nncf.get_graph()) - torch_quantized_model = min_max_algo.apply( - nncf_network, nncf_network.nncf.get_graph(), statistics_aggregator.statistic_points - ) + torch_quantized_model = min_max_algo.apply(ModelWrapper(nncf_network), statistics_aggregator.statistic_points) path = str(tmp_path / "torch_ptq_model.onnx") torch.onnx.export( - torch_quantized_model, + torch_quantized_model.model, input_data, path, export_params=True, diff --git a/tests/torch/ptq/test_fq_params_calculation.py b/tests/torch/ptq/test_fq_params_calculation.py index 6d71760cd33..9c2bbe861b8 100644 --- a/tests/torch/ptq/test_fq_params_calculation.py +++ b/tests/torch/ptq/test_fq_params_calculation.py @@ -16,6 +16,7 @@ import torch import 
nncf +from nncf.common.model import ModelWrapper from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters from nncf.quantization.advanced_parameters import OverflowFix @@ -58,8 +59,8 @@ def transform_fn(sample): original_model.eval() nncf_network = wrap_model(original_model, torch.ones([1, 1, 10, 10]), trace_parameters=True) - quantized_model = post_training_quantization.apply(nncf_network, nncf_network.nncf.get_graph(), dataset=dataset) - return quantized_model + quantized_model = post_training_quantization.apply(ModelWrapper(nncf_network), dataset=dataset) + return quantized_model.model def get_fq_nodes(model: NNCFNetwork) -> Dict[Scope, torch.nn.Module]: diff --git a/tests/torch/ptq/test_graphs.py b/tests/torch/ptq/test_graphs.py index eba35163c7c..be902f9daa5 100644 --- a/tests/torch/ptq/test_graphs.py +++ b/tests/torch/ptq/test_graphs.py @@ -16,6 +16,7 @@ import torch from nncf import Dataset +from nncf.common.model import ModelWrapper from nncf.parameters import ModelType from nncf.parameters import TargetDevice from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters @@ -121,8 +122,7 @@ def test_min_max_classification_quantized_graphs(desc: ModelDesc, quantization_p quantization_algorithm = PostTrainingQuantization(**quantization_parameters) quantized_model = quantization_algorithm.apply( - nncf_network, - nncf_network.nncf.get_graph(), + ModelWrapper(nncf_network), dataset=Dataset([example_input]), - ) + ).model check_graph(quantized_model.nncf.get_graph(), desc.dot_filename(), graph_dir)