Reduced the number of graph rebuilds (#2011)
### Changes

Extended the signature of algorithm methods with NNCFGraph, so the graph is built once and reused instead of being rebuilt inside each method. This change shows a 2.34x quantization speed-up for the "hf-internal-testing/tiny-random-GPTNeoXForCausalLM" model from optimum.

### Reason for changes

Reduced the number of NNCFGraph rebuilds: backend helpers, statistics aggregators, and algorithm steps now receive a pre-built graph instead of recreating it from the model.
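
Schematically, one NNCFGraph is now threaded through the statistics pipeline instead of each step calling NNCFGraphFactory.create(model) on its own. A minimal sketch of the assumed orchestration (method names follow the signatures changed in this diff; register_statistic_points is assumed to be the existing aggregator API):

```python
from nncf.common.factory import NNCFGraphFactory


def collect_statistics_once(model, algorithm, statistics_aggregator):
    # Build the NNCFGraph a single time ...
    graph = NNCFGraphFactory.create(model)
    # ... then reuse it both for discovering statistic points
    # and for collecting them, matching the extended signatures.
    statistic_points = algorithm.get_statistic_points(model, graph)
    statistics_aggregator.register_statistic_points(statistic_points)
    statistics_aggregator.collect_statistics(model, graph)
    return statistic_points
```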

### Related tickets

ref: 113245

### Tests

N/A
alexsu52 authored Aug 4, 2023
1 parent 39124f0 commit 215718f
Showing 37 changed files with 254 additions and 207 deletions.
11 changes: 7 additions & 4 deletions nncf/common/tensor_statistics/aggregator.py
@@ -17,6 +17,7 @@

from nncf.common.factory import EngineFactory
from nncf.common.factory import ModelTransformerFactory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
@@ -36,19 +37,20 @@ def __init__(self, dataset: Dataset):
self.stat_subset_size = None
self.statistic_points = StatisticPointsContainer()

def collect_statistics(self, model: TModel) -> None:
def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
"""
Collects statistics for registered StatisticPoints.
The statistics are stored in self.statistic_points.
:param model: backend-specific model instance
:param model: Backend-specific model instance.
:param graph: Model graph.
"""
if not self.statistic_points:
return

model_transformer = ModelTransformerFactory.create(model)

merged_statistics = self._get_merged_statistic_points(self.statistic_points, model)
merged_statistics = self._get_merged_statistic_points(self.statistic_points, model, graph)
transformation_layout = self._get_transformation_layout_extra_outputs(merged_statistics)
model_with_outputs = model_transformer.transform(transformation_layout)
engine = EngineFactory.create(model_with_outputs)
@@ -105,7 +107,7 @@ def _get_transformation_layout_extra_outputs(
@staticmethod
@abstractmethod
def _get_merged_statistic_points(
statistic_points: StatisticPointsContainer, model: TModel
statistic_points: StatisticPointsContainer, model: TModel, graph: NNCFGraph
) -> StatisticPointsContainer:
"""
Creates a new StatisticPointContainer that has no duplicated tensor collectors for one
@@ -115,6 +117,7 @@ def _get_merged_statistic_points(
:param statistic_points: Registered statistic points with possible tensor collectors duplicates.
:param model: Backend-specific target model.
:param graph: Model graph.
:return: Merged statistic points container bounded with given statistic point container.
"""

6 changes: 3 additions & 3 deletions nncf/experimental/torch/quantization/quantize_model.py
@@ -115,9 +115,9 @@ def quantize_impl(
advanced_parameters=advanced_parameters,
)

quantized_model = quantization_algorithm.apply(nncf_network, dataset=calibration_dataset)

# TODO (asuslov): quantized_model = quantized_model.strip()
quantized_model = quantization_algorithm.apply(
nncf_network, nncf_network.nncf.get_graph(), dataset=calibration_dataset
)

quantized_model.nncf.disable_dynamic_graph_building()

7 changes: 3 additions & 4 deletions nncf/onnx/graph/model_utils.py
@@ -13,7 +13,7 @@
import onnx

from nncf.common.factory import ModelTransformerFactory
from nncf.common.factory import NNCFGraphFactory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXDequantizeLinearMetatype
@@ -22,17 +22,16 @@
from nncf.onnx.graph.transformations.commands import ONNXTargetPoint


def remove_fq_from_inputs(model: onnx.ModelProto) -> onnx.ModelProto:
def remove_fq_from_inputs(model: onnx.ModelProto, nncf_graph: NNCFGraph) -> onnx.ModelProto:
"""
This method removes the activation Quantizer nodes from the model.
It's needed for the further bias shift calculation that relies on quantized weights.
:param model: onnx.ModelProto instance.
:param nncf_graph: NNCFGraph instance.
:return: onnx.ModelProto instance without activation Quantizer nodes.
"""
transformation_layout = TransformationLayout()
nncf_graph = NNCFGraphFactory.create(model)

model_transformer = ModelTransformerFactory.create(model)

seen_nodes = []
4 changes: 3 additions & 1 deletion nncf/onnx/quantization/quantize_model.py
@@ -16,6 +16,7 @@
from nncf.common.logging.logger import nncf_logger
from nncf.common.quantization.structs import QuantizationPreset
from nncf.data import Dataset
from nncf.onnx.graph.nncf_graph_builder import GraphConverter
from nncf.parameters import ModelType
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
@@ -65,6 +66,7 @@ def quantize_impl(
advanced_parameters=advanced_parameters,
)

quantized_model = quantization_algorithm.apply(model, dataset=calibration_dataset)
graph = GraphConverter.create_nncf_graph(model)
quantized_model = quantization_algorithm.apply(model, graph, dataset=calibration_dataset)

return quantized_model
11 changes: 5 additions & 6 deletions nncf/onnx/statistics/aggregator.py
@@ -14,8 +14,8 @@
import numpy as np
import onnx

from nncf.common.factory import NNCFGraphFactory
from nncf.common.factory import TModel
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.tensor_statistics.aggregator import StatisticsAggregator
@@ -28,12 +28,11 @@


class ONNXStatisticsAggregator(StatisticsAggregator):
def collect_statistics(self, model: onnx.ModelProto) -> None:
self._nncf_graph = NNCFGraphFactory.create(model)
self.input_edges_mapping = get_input_edges_mapping(self._nncf_graph)
def collect_statistics(self, model: onnx.ModelProto, graph: NNCFGraph) -> None:
self.input_edges_mapping = get_input_edges_mapping(graph)
self._onnx_graph = ONNXGraph(model)
self._registered_weights = set()
super().collect_statistics(model)
super().collect_statistics(model, graph)

def _register_statistics(
self, outputs: Dict[str, ONNXNNCFTensor], statistic_points: StatisticPointsContainer
@@ -71,7 +70,7 @@ def _get_transformation_layout_extra_outputs(

@staticmethod
def _get_merged_statistic_points(
statistic_points: StatisticPointsContainer, model: TModel
statistic_points: StatisticPointsContainer, model: TModel, graph: NNCFGraph
) -> StatisticPointsContainer:
# TODO: migrate to experimental statistic collector and use common merging algorithm
return statistic_points
19 changes: 9 additions & 10 deletions nncf/openvino/graph/model_utils.py
@@ -13,7 +13,7 @@
import openvino.runtime as ov

from nncf.common.factory import ModelTransformerFactory
from nncf.common.factory import NNCFGraphFactory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.openvino.graph.metatypes.common import FAKE_QUANTIZE_OPERATIONS
from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionBackpropDataMetatype
@@ -25,11 +25,12 @@
from nncf.openvino.graph.transformations.command_creation import OVCommandCreator


def insert_null_biases(model: ov.Model) -> ov.Model:
def insert_null_biases(model: ov.Model, graph: NNCFGraph) -> ov.Model:
"""
This method finds and inserts zero biases for the layers that should have it.
:param model: ov.Model instance.
:param graph: Model graph.
:return: Updated ov.Model instance with zero biases
"""
types_to_insert_bias = [
@@ -39,9 +40,8 @@ def insert_null_biases(model: ov.Model) -> ov.Model:
OVConvolutionBackpropDataMetatype,
OVGroupConvolutionBackpropDataMetatype,
]
nncf_graph = NNCFGraphFactory.create(model)
nodes_without_biases = nncf_graph.get_nodes_by_metatypes(types_to_insert_bias)
nodes_without_biases = [node for node in nodes_without_biases if not is_node_with_bias(node, nncf_graph)]
nodes_without_biases = graph.get_nodes_by_metatypes(types_to_insert_bias)
nodes_without_biases = [node for node in nodes_without_biases if not is_node_with_bias(node, graph)]
transformation_layout = TransformationLayout()
model_transformer = ModelTransformerFactory.create(model)
for node_without_bias in nodes_without_biases:
@@ -50,21 +50,20 @@ def insert_null_biases(model: ov.Model) -> ov.Model:
return model_transformer.transform(transformation_layout)


def remove_fq_from_inputs(model: ov.Model) -> ov.Model:
def remove_fq_from_inputs(model: ov.Model, graph: NNCFGraph) -> ov.Model:
"""
This method removes the activation Fake Quantize nodes from the model.
It's needed for the further bias shift calculation that relies on quantized weights.
:param model: ov.Model instance.
:param graph: NNCFGraph instance.
:return: ov.Model instance without activation Fake Quantize nodes.
"""
transformation_layout = TransformationLayout()
nncf_graph = NNCFGraphFactory.create(model)

model_transformer = ModelTransformerFactory.create(model)

seen_nodes = []
nodes_queue = deque(nncf_graph.get_input_nodes())
nodes_queue = deque(graph.get_input_nodes())
while nodes_queue:
current_node = nodes_queue.popleft()
current_node_name = current_node.node_name
@@ -76,6 +75,6 @@ def remove_fq_from_inputs(model: ov.Model) -> ov.Model:
if current_node.metatype in FAKE_QUANTIZE_OPERATIONS:
command = OVCommandCreator.create_command_to_remove_quantizer(current_node)
transformation_layout.register(command)
nodes_queue.extend(nncf_graph.get_next_nodes(current_node))
nodes_queue.extend(graph.get_next_nodes(current_node))

return model_transformer.transform(transformation_layout)
4 changes: 3 additions & 1 deletion nncf/openvino/quantization/quantize_model.py
@@ -19,6 +19,7 @@
from nncf.common.logging import nncf_logger
from nncf.common.quantization.structs import QuantizationPreset
from nncf.data import Dataset
from nncf.openvino.graph.nncf_graph_builder import GraphConverter
from nncf.openvino.quantization.backend_parameters import BackendParameters
from nncf.openvino.quantization.backend_parameters import is_weight_compression_needed
from nncf.parameters import DropType
@@ -111,7 +112,8 @@ def native_quantize_impl(
advanced_parameters=advanced_parameters,
)

quantized_model = quantization_algorithm.apply(model, dataset=calibration_dataset)
graph = GraphConverter.create_nncf_graph(model)
quantized_model = quantization_algorithm.apply(model, graph, dataset=calibration_dataset)

if is_weight_compression_needed(advanced_parameters):
compress_quantize_weights_transformation(quantized_model)
13 changes: 6 additions & 7 deletions nncf/openvino/statistics/aggregator.py
@@ -15,23 +15,23 @@
import numpy as np
import openvino.runtime as ov

from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.tensor_statistics.aggregator import StatisticsAggregator
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.experimental.common.tensor_statistics.collectors import MergedTensorCollector
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.openvino.graph.nncf_graph_builder import GraphConverter
from nncf.openvino.graph.transformations.commands import OVInplaceFnInsertionCommand
from nncf.openvino.graph.transformations.commands import OVOutputInsertionCommand
from nncf.openvino.tensor import OVNNCFTensor


class OVStatisticsAggregator(StatisticsAggregator):
def collect_statistics(self, model: ov.Model) -> None:
def collect_statistics(self, model: ov.Model, graph: NNCFGraph) -> None:
self._name_to_node_mapping = {op.get_friendly_name(): op for op in model.get_ops()}
super().collect_statistics(model)
super().collect_statistics(model, graph)

def _register_statistics(
self, outputs: Dict[str, OVNNCFTensor], statistic_points: StatisticPointsContainer
@@ -75,17 +75,16 @@ def _get_transformation_layout_extra_outputs(
@staticmethod
# TODO(dlyakhov) Move this to common part
def _get_merged_statistic_points(
statistic_points: StatisticPointsContainer, model: ov.Model
statistic_points: StatisticPointsContainer, model: ov.Model, graph: NNCFGraph
) -> StatisticPointsContainer:
nncf_graph = GraphConverter.create_nncf_graph(model)
merged_statistic_points = StatisticPointsContainer()
target_type_to_tensor_collector_map = defaultdict(lambda: defaultdict(list))
for target_node_name, _statistic_points in statistic_points.data.items():
for statistic_point in _statistic_points:
target_point = statistic_point.target_point
if target_point.type in [TargetType.PRE_LAYER_OPERATION, TargetType.OPERATION_WITH_WEIGHTS]:
node = nncf_graph.get_node_by_name(target_node_name)
target_input_edge = nncf_graph.get_input_edges(node)[target_point.port_id]
node = graph.get_node_by_name(target_node_name)
target_input_edge = graph.get_input_edges(node)[target_point.port_id]

target_type = TargetType.POST_LAYER_OPERATION
_target_node_name = target_input_edge.from_node.node_name
42 changes: 17 additions & 25 deletions nncf/quantization/algorithms/algorithm.py
@@ -14,18 +14,13 @@
from typing import Dict, Optional, TypeVar

from nncf import Dataset
from nncf.common.graph.graph import NNCFGraph
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType

TModel = TypeVar("TModel")


class AlgorithmParameters(ABC):
"""
Base class for Post-Training algorithm parameters.
"""


class Algorithm(ABC):
"""
Base class for all Post-Training algorithms.
@@ -35,38 +30,35 @@ class Algorithm(ABC):
@abstractmethod
def available_backends(self) -> Dict[str, BackendType]:
"""
Returns dictionary of the available backends for the algorithm
Returns dictionary of the available backends for the algorithm.
:return: Dict of backends supported by the algorithm
:return: Dict of backends supported by the algorithm.
"""

@abstractmethod
def apply(
self,
model: TModel,
graph: NNCFGraph,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> TModel:
"""
Checks that statistic point exists, sets model into transformer
and applies the algorithm to the model.
:param model: model for applying algorithm
:param engine: engine for the model execution
:param statistic_points: StatisticPointsContainer
:return: model after algorithm
"""
# TODO (asuslov): add validation statistic_points
return self._apply(model, statistic_points=statistic_points, dataset=dataset)

@abstractmethod
def _apply(
self, model: TModel, statistic_points: StatisticPointsContainer, dataset: Optional[Dataset] = None
) -> TModel:
"""
Applies the algorithm to the model.
:param model: Model for applying algorithm.
:param graph: Model graph.
:param statistic_points: Statistic points with collected statistics values.
:param dataset: A representative dataset for the calibration process.
:return: A resulting model.
"""

@abstractmethod
def get_statistic_points(self, model: TModel) -> StatisticPointsContainer:
def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
"""
Returns activation layers, for which StatisticsCollector should collect statistics.
Returns statistic points, for which StatisticsCollector should collect statistics.
:param model: Model for statistics collection.
:param graph: Model graph.
:return: Statistic points, for which StatisticsCollector should collect statistics.
"""
15 changes: 9 additions & 6 deletions nncf/quantization/algorithms/bias_correction/algorithm.py
@@ -127,19 +127,21 @@ def _set_backend_entity(self, model: TModel) -> None:
"Cannot return backend-specific entity because {} is not supported!".format(model_backend)
)

def _apply(
def apply(
self,
model: TModel,
graph: NNCFGraph,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> TModel:
self._set_backend_entity(model)
model = self._backend_entity.insert_null_biases(model)
model = self._backend_entity.insert_null_biases(model, graph)
main_transformations_layout = TransformationLayout()
main_model_transformer = ModelTransformerFactory.create(model)

model_copy = copy_model(model)
model_copy = self._backend_entity.remove_fq_from_inputs(model_copy)
graph_copy = NNCFGraphFactory.create(model_copy)
model_copy = self._backend_entity.remove_fq_from_inputs(model_copy, graph_copy)
nncf_graph = NNCFGraphFactory.create(model_copy)

nodes_with_bias = []
@@ -484,10 +486,11 @@ def output_filter_func(point):
output_fp.extend(tensor_collector.get_statistics().mean_values)
return np.array(output_fp)

def get_statistic_points(self, model: TModel) -> StatisticPointsContainer:
def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
self._set_backend_entity(model)
model_copy = self._backend_entity.remove_fq_from_inputs(copy_model(model))
model_copy = self._backend_entity.insert_null_biases(model_copy)
model_copy = self._backend_entity.remove_fq_from_inputs(copy_model(model), graph)
graph_copy = NNCFGraphFactory.create(model_copy)
model_copy = self._backend_entity.insert_null_biases(model_copy, graph_copy)
nncf_graph = NNCFGraphFactory.create(model_copy)
statistic_container = StatisticPointsContainer()
