
Commit

init model_wrapper
AlexanderDokuchaev committed Dec 9, 2024
1 parent 6031ccc commit 948d72c
Showing 25 changed files with 220 additions and 131 deletions.
20 changes: 17 additions & 3 deletions nncf/common/factory.py
@@ -9,7 +9,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TypeVar
import os
from typing import Any, Dict, Optional, Tuple, TypeVar

import nncf
from nncf.common.engine import Engine
@@ -26,13 +27,20 @@

class NNCFGraphFactory:
@staticmethod
def create(model: TModel) -> NNCFGraph:
def create(
model: TModel, input_args: Optional[Tuple[Any, ...]] = None, input_kwargs: Optional[Dict[str, Any]] = None
) -> NNCFGraph:
"""
Factory method to create backend-specific NNCFGraph instance based on the input model.
:param model: backend-specific model instance
:return: backend-specific NNCFGraph instance
"""
if input_args is None:
input_args = ()
if input_kwargs is None:
input_kwargs = {}

model_backend = get_backend(model)
if model_backend == BackendType.ONNX:
from nncf.onnx.graph.nncf_graph_builder import GraphConverter
@@ -47,7 +55,13 @@ def create(model: TModel) -> NNCFGraph:

return GraphConverter.create_nncf_graph(model)
if model_backend == BackendType.TORCH:
return model.nncf.get_graph()
if os.getenv("NNCF_EXPERIMENTAL_TORCH_TRACING") is None:
return model.nncf.get_graph()
else:
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph

return build_nncf_graph(model, *input_args, **input_kwargs)

raise nncf.UnsupportedBackendError(
"Cannot create backend-specific graph because {} is not supported!".format(model_backend.value)
)
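For context, a minimal usage sketch (not part of the commit) of the extended factory signature for a PyTorch model with the experimental tracing path enabled. The toy model, example input, and the assumption that the experimental builder can trace a plain nn.Module given example inputs are illustrative, not taken from the diff.

# Illustrative sketch only: exercises the new input_args/input_kwargs parameters.
import os

import torch

from nncf.common.factory import NNCFGraphFactory

# Assumption: setting the flag routes torch graph building through
# nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder.build_nncf_graph.
os.environ["NNCF_EXPERIMENTAL_TORCH_TRACING"] = "1"

model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU())  # toy model
example_input = torch.ones(1, 8)

# input_args/input_kwargs are forwarded as build_nncf_graph(model, *input_args, **input_kwargs).
graph = NNCFGraphFactory.create(model, input_args=(example_input,))
print(len(graph.get_all_nodes()))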
67 changes: 67 additions & 0 deletions nncf/common/model.py
@@ -0,0 +1,67 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Any, Dict, Optional, TypeVar

from nncf.common.factory import NNCFGraphFactory
from nncf.common.graph.graph import NNCFGraph

TModel = TypeVar("TModel")


class StateAttributes:
"""
The state attributes.
"""

EXAMPLE_INPUT_ARGS = "example_input_args"
EXAMPLE_INPUT_KWARGS = "example_input_kwargs"


class ModelWrapper:
"""
A wrapper class for the original model.
:param model: The original model to be wrapped.
:param graph: The NNCFGraph representation of the model; built lazily on first access if not provided.
:param state: The storage of the model state.
"""

def __init__(
self, model: TModel, graph: Optional[NNCFGraph] = None, state: Optional[Dict[str, Any]] = None
) -> None:
self._model = model
self._graph = graph
self.state = state if state is not None else {}

@property
def model(self) -> TModel:
"""
Retrieves the original model.
"""
return self._model

@property
def graph(self) -> NNCFGraph:
"""
Returns the NNCFGraph representation of the model.
If the graph has not been created yet, it will be created using the model,
example input arguments, and example input keyword arguments stored in the state.
"""
if self._graph is None:
self._graph = NNCFGraphFactory.create(
model=self.model,
input_args=self.state.get(StateAttributes.EXAMPLE_INPUT_ARGS),
input_kwargs=self.state.get(StateAttributes.EXAMPLE_INPUT_KWARGS),
)
return self._graph
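A brief usage sketch (not from the commit) of the lazy graph construction: the wrapper stores example inputs in its state and only builds the NNCFGraph on first access. The toy model and inputs are illustrative; for a plain PyTorch module this also assumes the experimental tracing path shown in factory.py above.

import torch

from nncf.common.model import ModelWrapper
from nncf.common.model import StateAttributes

model = torch.nn.Linear(4, 2)  # toy model for illustration
wrapper = ModelWrapper(
    model,
    state={
        StateAttributes.EXAMPLE_INPUT_ARGS: (torch.ones(1, 4),),
        StateAttributes.EXAMPLE_INPUT_KWARGS: {},
    },
)

# No graph is built yet; the first access calls NNCFGraphFactory.create(...)
# with the example inputs stored in the state, then the result is cached.
nncf_graph = wrapper.graph
assert wrapper.graph is nncf_graph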
12 changes: 7 additions & 5 deletions nncf/onnx/quantization/quantize_model.py
@@ -15,10 +15,10 @@

import nncf
from nncf.common.logging.logger import nncf_logger
from nncf.common.model import ModelWrapper
from nncf.common.quantization.structs import QuantizationPreset
from nncf.data import Dataset
from nncf.onnx.graph.metatypes.groups import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS
from nncf.onnx.graph.nncf_graph_builder import GraphConverter
from nncf.parameters import DropType
from nncf.parameters import ModelType
from nncf.parameters import QuantizationMode
@@ -78,11 +78,13 @@ def quantize_impl(
advanced_parameters=advanced_parameters,
)

graph = GraphConverter.create_nncf_graph(model)
warning_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS)
quantized_model = quantization_algorithm.apply(model, graph, dataset=calibration_dataset)
model_wrapper = ModelWrapper(model)
warning_model_no_batchwise_support(
model_wrapper.graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS
)
quantized_model = quantization_algorithm.apply(model_wrapper, dataset=calibration_dataset)

return quantized_model
return quantized_model.model


def quantize_with_accuracy_control_impl(
5 changes: 4 additions & 1 deletion nncf/openvino/quantization/quantize_ifmodel.py
@@ -25,6 +25,7 @@
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging import nncf_logger
from nncf.common.logging.track_progress import track
from nncf.common.model import ModelWrapper
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.openvino.graph.metatypes.openvino_metatypes import OVIfMetatype
from nncf.openvino.graph.model_utils import remove_friendly_name_duplicates
@@ -155,7 +156,9 @@ def apply_algorithm_if_bodies(
"""
nncf_logger.info(f"Iteration [{current_model_num}/{len(graphs)}] ...")
parent_graph = graphs[graph_id]
quantized_model = algorithm.apply(parent_model, parent_graph, parent_statistic_points, parent_dataset)
quantized_model = algorithm.apply(
ModelWrapper(parent_model, parent_graph), parent_statistic_points, parent_dataset
).model
if get_number_if_op(parent_model) == 0:
return quantized_model, current_model_num
model_transformer_fp32 = factory.ModelTransformerFactory.create(parent_model)
10 changes: 6 additions & 4 deletions nncf/openvino/quantization/quantize_model.py
@@ -19,13 +19,13 @@
from nncf.common.factory import NNCFGraphFactory
from nncf.common.factory import StatisticsAggregatorFactory
from nncf.common.logging import nncf_logger
from nncf.common.model import ModelWrapper
from nncf.common.quantization.structs import QuantizationPreset
from nncf.data import Dataset
from nncf.openvino.graph.metatypes.groups import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS
from nncf.openvino.graph.metatypes.openvino_metatypes import OVIfMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import get_node_metatype
from nncf.openvino.graph.model_utils import remove_friendly_name_duplicates
from nncf.openvino.graph.nncf_graph_builder import GraphConverter
from nncf.openvino.graph.node_utils import get_number_if_op
from nncf.openvino.quantization.backend_parameters import BackendParameters
from nncf.openvino.quantization.backend_parameters import is_weight_compression_needed
@@ -166,9 +166,11 @@ def native_quantize_impl(
ignored_scope=ignored_scope,
advanced_parameters=advanced_parameters,
)
graph = GraphConverter.create_nncf_graph(model)
warning_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS)
quantized_model = quantization_algorithm.apply(model, graph, dataset=calibration_dataset)
model_wrapper = ModelWrapper(model)
warning_model_no_batchwise_support(
model_wrapper.graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS
)
quantized_model = quantization_algorithm.apply(model_wrapper, dataset=calibration_dataset).model

if is_weight_compression_needed(advanced_parameters):
compress_quantize_weights_transformation(quantized_model)
9 changes: 3 additions & 6 deletions nncf/quantization/algorithms/algorithm.py
@@ -14,7 +14,7 @@
from typing import List, Optional, TypeVar

from nncf import Dataset
from nncf.common.graph.graph import NNCFGraph
from nncf.common.model import ModelWrapper
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType

@@ -38,27 +38,24 @@ def available_backends(self) -> List[BackendType]:
@abstractmethod
def apply(
self,
model: TModel,
graph: NNCFGraph,
model: ModelWrapper,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> TModel:
"""
Applies the algorithm to the model.
:param model: Model for applying algorithm.
:param graph: Model graph.
:param statistic_points: Statistic points with collected statistics values.
:param dataset: A representative dataset for the calibration process.
:return: A resulting model.
"""

@abstractmethod
def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
def get_statistic_points(self, model: ModelWrapper) -> StatisticPointsContainer:
"""
Returns statistic points, for which StatisticsCollector should collect statistics.
:param model: Model for statistics collection.
:param graph: Model graph.
:return: Statistic points, for which StatisticsCollector should collect statistics.
"""
11 changes: 7 additions & 4 deletions nncf/quantization/algorithms/bias_correction/algorithm.py
@@ -25,6 +25,7 @@
from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging.track_progress import track
from nncf.common.model import ModelWrapper
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
@@ -133,11 +134,11 @@ def _set_backend_entity(self, model: TModel) -> None:

def apply(
self,
model: TModel,
graph: NNCFGraph,
model_wrapper: ModelWrapper,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> TModel:
model = model_wrapper.model
self._set_backend_entity(model)
main_transformations_layout = TransformationLayout()
main_model_transformer = ModelTransformerFactory.create(model)
@@ -553,8 +554,10 @@ def output_filter_func(point):
output_fp.extend(tensor_collector.get_statistics().mean_values)
return output_fp

def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
self._set_backend_entity(model)
def get_statistic_points(self, model_wrapper: ModelWrapper) -> StatisticPointsContainer:
model = model_wrapper.model
graph = model_wrapper.graph
self._set_backend_entity(model_wrapper.model)
statistic_container = StatisticPointsContainer()

nodes_with_bias = [
6 changes: 4 additions & 2 deletions nncf/quantization/algorithms/channel_alignment/algorithm.py
@@ -25,6 +25,7 @@
from nncf.common.graph.utils import get_reduction_axes
from nncf.common.logging import nncf_logger
from nncf.common.logging.track_progress import track
from nncf.common.model import ModelWrapper
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
@@ -381,8 +382,9 @@ def _get_target_point_and_node_in(self, conv_in, add_in) -> Tuple[TargetPoint, N
node_in,
)

def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
self._set_backend_entity(model)
def get_statistic_points(self, model_wrapper: ModelWrapper) -> StatisticPointsContainer:
self._set_backend_entity(model_wrapper.model)
graph = model_wrapper.graph

statistic_container = StatisticPointsContainer()
for conv_in, add_in, _ in self._get_node_pairs(graph):
17 changes: 11 additions & 6 deletions nncf/quantization/algorithms/fast_bias_correction/algorithm.py
@@ -16,13 +16,13 @@
from nncf import Dataset
from nncf.common.factory import EngineFactory
from nncf.common.factory import ModelTransformerFactory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.model_transformer import ModelTransformer
from nncf.common.graph.transformations.commands import TargetPoint
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging import nncf_logger
from nncf.common.logging.track_progress import track
from nncf.common.model import ModelWrapper
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
@@ -129,11 +129,12 @@ def _set_backend_entity(self, model: TModel) -> None:

def apply(
self,
model: TModel,
graph: NNCFGraph,
model_wrapper: ModelWrapper,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> TModel:
) -> ModelWrapper:
model = model_wrapper.model
graph = model_wrapper.graph
self._set_backend_entity(model)

model_transformer = ModelTransformerFactory.create(model)
@@ -207,7 +208,9 @@ def apply(
transformation_layout.register(self._backend_entity.create_bias_correction_command(node, bias_value, graph))
transformed_model = model_transformer.transform(transformation_layout)

return transformed_model
return ModelWrapper(
model=transformed_model, graph=graph, state=model_wrapper.state  # bias correction does not change the model's graph
)

@staticmethod
def _get_bias_shift_magnitude(current_bias_value: Tensor, updated_bias_value: Tensor) -> Tensor:
@@ -345,7 +348,9 @@ def _get_bias_shift(
bias_shift = fns.stack(output_fp) - q_outputs
return bias_shift

def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
def get_statistic_points(self, model_wrapper: ModelWrapper) -> StatisticPointsContainer:
model = model_wrapper.model
graph = model_wrapper.graph
self._set_backend_entity(model)
nodes_with_bias = [
node for node in graph.get_all_nodes() if self._backend_entity.is_node_with_bias(node, graph)
23 changes: 13 additions & 10 deletions nncf/quantization/algorithms/min_max/algorithm.py
@@ -31,6 +31,7 @@
from nncf.common.hardware.config import get_hw_config_type
from nncf.common.insertion_point_graph import InsertionPointGraph
from nncf.common.logging import nncf_logger
from nncf.common.model import ModelWrapper
from nncf.common.quantization.config_assignment import assign_qconfig_lists_to_modules
from nncf.common.quantization.initialization.range import RangeInitCollectorParams
from nncf.common.quantization.quantizer_propagation.solver import QuantizerPropagationRule
@@ -889,14 +890,16 @@ def _get_quantization_points_overflow_fix(

def apply(
self,
model: TModel,
graph: NNCFGraph,
model_wrapper: ModelWrapper,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> TModel:
) -> ModelWrapper:
transformation_layout = TransformationLayout()
model_transformer = ModelTransformerFactory.create(model)
quantization_target_points, unified_scale_groups = self._get_quantization_target_points(model, graph)
model_transformer = ModelTransformerFactory.create(model_wrapper.model)
graph = model_wrapper.graph
quantization_target_points, unified_scale_groups = self._get_quantization_target_points(
model_wrapper.model, graph
)
quantization_points_overflow_fix = self._get_quantization_points_overflow_fix(
self._overflow_fix, quantization_target_points, graph
)
@@ -987,20 +990,20 @@ def filter_func(point: StatisticPoint) -> bool:
if not transformation_layout.transformations:
nncf_logger.info("The model has no operations to apply quantization.")
quantized_model = model_transformer.transform(transformation_layout)
return quantized_model
return ModelWrapper(quantized_model, state=model_wrapper.state)

def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
self._set_backend_entity(model)
def get_statistic_points(self, model_wrapper: ModelWrapper) -> StatisticPointsContainer:
self._set_backend_entity(model_wrapper.model)
self._reset_cache()
quantization_target_points, _ = self._get_quantization_target_points(model, graph)
quantization_target_points, _ = self._get_quantization_target_points(model_wrapper.model, model_wrapper.graph)
output = StatisticPointsContainer()
for quantization_target_point, qconfig in quantization_target_points.items():
nncf_logger.debug(
f"Adding target point {quantization_target_point.target_node_name}"
f" with type {quantization_target_point.type} for statistics collection"
)
stat_collector = self._get_stat_collector(
graph, quantization_target_point, qconfig, self._batchwise_statistics
model_wrapper.graph, quantization_target_point, qconfig, self._batchwise_statistics
)
output.add_statistic_point(
StatisticPoint(
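A short sketch (not from the commit) of the two return conventions visible in the hunks above: FastBiasCorrection only rewrites bias constants, so it forwards the existing graph; MinMaxQuantization inserts quantizer nodes, so it drops the graph and lets ModelWrapper rebuild it lazily from the example inputs kept in the state. The helper names below are invented for illustration.

from typing import Any

from nncf.common.model import ModelWrapper


def wrap_keeping_graph(wrapper: ModelWrapper, transformed_model: Any) -> ModelWrapper:
    # Structure is unchanged (e.g. only bias values edited): reuse the cached NNCFGraph.
    return ModelWrapper(transformed_model, graph=wrapper.graph, state=wrapper.state)


def wrap_invalidating_graph(wrapper: ModelWrapper, transformed_model: Any) -> ModelWrapper:
    # Structure changed (e.g. quantizers inserted): omit the graph so it is
    # re-created by NNCFGraphFactory on the next wrapper.graph access.
    return ModelWrapper(transformed_model, state=wrapper.state)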
(Diffs for the remaining 15 changed files are not shown.)