Dynamic OV model builder (#3137)
### Changes

- Added the `OVModelBuilder` class.
- Updated the FastBC algorithm to use the new approach.

### Reason for changes

- Algorithm speed-up.

### Related tickets

- 122317

### Tests

- Added `tests/openvino/native/test_model_builder.py`.
- Conversion jobs were run using DLB scope.

Results (develop run - manual/post_training_quantization_performance/91, 92; branch - manual/post_training_quantization_performance/90):

Model | Backend | FBC time (develop, OV) | FBC time (develop, PT) | FBC time (branch, OV) | Diff (develop - branch, OV)
-- | -- | -- | -- | -- | --
hf/bert-base-uncased | OV | 00:00:02 | - | 00:00:01 | **00:00:01**
torchvision/resnet18 | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
torchvision/mobilenet_v3_small_BC | OV | 00:00:01 | - | 00:00:01 | 00:00:00
torchvision/vit_b_16 | OV | 00:00:02 | - | 00:00:01 | **00:00:01**
torchvision/swin_v2_s | OV | 00:00:05 | - | 00:00:01 | **00:00:04**
timm/crossvit_9_240 | OV | 00:00:01 | 00:00:00 | **00:00:01** | 00:00:00
timm/darknet53 | OV | 00:00:01 | 00:00:00 | **00:00:01** | 00:00:00
timm/deit3_small_patch16_224 | OV | 00:00:01 | 00:00:00 | 00:00:00 | **00:00:01**
timm/dla34 | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/dpn68 | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/efficientnet_b0 | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/efficientnet_b0_BC | OV | 00:00:04 | - | 00:00:04 | 00:00:00
timm/efficientnet_lite0 | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/hrnet_w18 | OV | 00:00:13 | 00:00:03 | **00:00:04** | **00:00:09**
timm/inception_resnet_v2 | OV | 00:00:08 | 00:00:03 | **00:00:04** | **00:00:04**
timm/levit_128 | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/mobilenetv2_050 | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/mobilenetv2_050_BC | OV | 00:00:03 | - | 00:00:03 | 00:00:00
timm/mobilenetv3_small_050 | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/mobilenetv3_small_050_BC | OV | 00:00:01 | - | 00:00:01 | 00:00:00
timm/regnetx_002 | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/resnest14d | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/swin_base_patch4_window7_224 | OV | 00:00:03 | 00:00:00 | **00:00:02** | **00:00:01**
timm/tf_inception_v3 | OV | 00:00:01 | 00:00:01 | 00:00:01 | 00:00:00
timm/vgg11 | OV | 00:00:01 | 00:00:00 | **00:00:01** | 00:00:00
timm/visformer_small | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/wide_resnet50_2 | OV | 00:00:01 | 00:00:00 | **00:00:01** | 00:00:00
KodiaqQ authored Jan 17, 2025
1 parent 883c787 commit 438871e
Showing 10 changed files with 377 additions and 26 deletions.
226 changes: 226 additions & 0 deletions nncf/openvino/graph/model_builder.py
@@ -0,0 +1,226 @@
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
from typing import Dict, List, Tuple

import openvino.runtime as ov
from openvino.runtime import opset13 as opset
from openvino.runtime.utils.node_factory import NodeFactory

from nncf.openvino.graph.model_transformer import OVModelTransformer
from nncf.openvino.graph.node_utils import get_parameter_node_name
from nncf.openvino.graph.node_utils import get_result_node_name


class OVModelBuilder:
    """
    The purpose of the OVModelBuilder is to build a new OpenVINO model from the given input and output points.
    This builder was created to reduce the amount of model cloning that is required for the ModelTransformer to work.
    """

    def __init__(self):
        self._node_factory = NodeFactory()

    @staticmethod
    def _create_parameter(node_name: str, node_input: ov.Input) -> ov.Node:
        """
        Creates a Parameter node for the new model using the name template for the given node input.
        """
        port_id = node_input.get_index()
        parameter_name = get_parameter_node_name(node_name, port_id)
        return opset.parameter(
            shape=node_input.get_partial_shape(),
            dtype=node_input.get_element_type(),
            name=parameter_name,
        )

    @staticmethod
    def _create_result(node_name: str, node_output: ov.Output) -> ov.Node:
        """
        Creates a Result node for the new model using the name template for the given node output.
        """
        port_id = node_output.get_index()
        result_name = get_result_node_name(node_name, port_id=port_id)
        result = opset.result(node_output, name=result_name)
        result.get_output_tensor(0).set_names({result_name})
        return result

    def _collect_graph_nodes(
        self,
        input_ids: List[Tuple[str, int]],
        output_ids: List[Tuple[str, int]],
        node_mapping: Dict[str, ov.Node],
    ) -> List[ov.Node]:
        """
        A method for aggregating the layers that should be cloned later.
        The aggregation is designed so that the layers are listed from right to left
        and traversed from bottom to top. This is done in order to find all constants in the model and
        to start the graph creation from them (as well as from the Parameter layers), because
        an OpenVINO graph is created top-down and cannot be created in any other order.

        Legend: w - weights, c - convert, il/ih - input low/high, ol/oh - output low/high

                 (w)
                  |
                 (c)  (il)  (ih)  (ol)  (oh)
                   \    |     |    /    /
                    (fake quantize)    (parameter)
                           \           /
                           (convolution)
                                 |
                             (result)

        Based on the above graph, the return value would look like this:
        [convolution, parameter, fake quantize, oh, ol, ih, il, c, w]

        :param input_ids: List of the points in the special format - (node_name, port_id).
            This helps to point to the precise part of the model that may be used to define the subgraph inputs.
        :param output_ids: List of the points in the special format - (node_name, port_id).
            This helps to point to the precise part of the model that may be used to define the subgraph outputs.
        :param node_mapping: Original nodes mapping.
        :return: List of the ov.Nodes to clone.
        """
        # A deque is used for FIFO insertion and retrieval of the layers.
        lookup_nodes = deque(node_mapping[n] for n, _ in output_ids)
        graph_nodes = []

        while lookup_nodes:
            lookup_node = lookup_nodes.popleft()
            lookup_name = lookup_node.get_friendly_name()
            node_inputs = lookup_node.inputs()
            graph_nodes.append(lookup_node)
            # Reversing to look up the nodes from right to left.
            for node_input in reversed(node_inputs):
                port_id = node_input.get_index()
                if (lookup_name, port_id) in input_ids:
                    # Parameters are created here to avoid double creation later on, since they are not original
                    # nodes but are needed as inputs for the next node.
                    parameter = self._create_parameter(lookup_name, node_input)
                    lookup_nodes.append(parameter)
                    continue
                parent_node = node_input.get_source_output().get_node()
                lookup_nodes.append(parent_node)

        return graph_nodes

    def build(
        self,
        input_ids: List[Tuple[str, int]],
        output_ids: List[Tuple[str, int]],
        node_mapping: Dict[str, ov.Node],
    ) -> ov.Model:
        """
        The main method of the algorithm. It recreates the aggregated list of layers as a new model.
        Let us take a graph of this kind as an example:

        Legend: w - weights, c - convert, il/ih - input low/high, ol/oh - output low/high

                 (w)
                  |
                 (c)  (il)  (ih)  (ol)  (oh)
                   \    |     |    /    /
                    (fake quantize)    (parameter)
                           \           /
                           (convolution)
                                 |
                             (result)

        The externally collected list of layers will look like this:
        [convolution, parameter, fake quantize, oh, ol, ih, il, c, w]

        Next, this list is traversed from right to left. At the same time, the list of already created layers
        is filled from left to right and consumed from left to right as well,
        in order to keep the order of the original layer inputs.
        For example:

        graph_nodes = [convolution, parameter, fake quantize, oh, ol, ih, il, c, w]
        clone_nodes = []

        *creating w - the weight node.*
        graph_nodes = [convolution, parameter, fake quantize, oh, ol, ih, il, c]
        clone_nodes = [w]

        *creating c - the convert node.
        Based on the .inputs() output, the already created w - weight node is used to fill in the convert input.
        As a result, the weight node is removed from the clone_nodes list and the convert node is placed there.*
        graph_nodes = [convolution, parameter, fake quantize, oh, ol, ih, il]
        clone_nodes = [c]

        *creating il/ih - input low/high and ol/oh - output low/high nodes.
        Since these nodes are constants and do not require any nodes as inputs, no cloned nodes are consumed.*
        graph_nodes = [convolution, parameter, fake quantize]
        clone_nodes = [c, il, ih, ol, oh]

        *creating the fake quantize node.
        This node requires its input values in a specific order.
        All previously cloned nodes are connected to the fake quantize node, from left to right.*
        graph_nodes = [convolution, parameter]
        clone_nodes = [f]

        *creating the parameter node.
        In this step, the list of parameters is also extended with the new node.*
        graph_nodes = [convolution]
        clone_nodes = [f, parameter]

        *creating the convolution node.
        This node also requires its inputs in a specific order.
        All previously cloned nodes are connected to the convolution, from left to right. Also,
        the output verification step shows here that one of the convolution outputs is in the output_ids list.
        This means that a Result node is created and placed into the results list.*
        graph_nodes = []
        clone_nodes = [convolution]

        The last step is to create a subgraph model based on the parameters & results lists.

        :param input_ids: List of the points in the special format - (node_name, port_id).
            This helps to point to the precise part of the model that may be used to define the subgraph inputs.
        :param output_ids: List of the points in the special format - (node_name, port_id).
            This helps to point to the precise part of the model that may be used to define the subgraph outputs.
        :param node_mapping: Original nodes mapping.
        :return: Built ov.Model based on the parameters.
        """

        parameters, results = [], []
        clone_nodes = deque()

        # Collect the nodes that define the graph.
        graph_nodes = self._collect_graph_nodes(input_ids, output_ids, node_mapping)

        while graph_nodes:
            graph_node = graph_nodes.pop()
            node_type = graph_node.get_type_name()
            node_name = graph_node.get_friendly_name()

            # To create the new OpenVINO nodes, we need to provide all possible layer attributes.
            attrs = graph_node.get_attributes()
            attrs["name"] = node_name

            if node_type == "Constant":
                # Constants are created separately due to their specific behavior.
                clone_node = OVModelTransformer._create_constant(
                    graph_node.get_data(), dtype=graph_node.get_element_type(), name=attrs["name"]
                )
            elif node_type == "Parameter":
                # The Parameter nodes were already created in the previous step.
                clone_node = graph_node
                parameters.append(clone_node)
            else:
                # All of the inputs of a node are nodes themselves, so the already cloned nodes are passed as args.
                args = [clone_nodes.popleft() for _ in graph_node.inputs()]

                clone_node = self._node_factory.create(node_type, args, attrs)

            for node_output in clone_node.outputs():
                port_id = node_output.get_index()
                if (node_name, port_id) in output_ids:
                    result = self._create_result(node_name, node_output)
                    results.append(result)

            clone_nodes.append(clone_node)

        return ov.Model(results, parameters)
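
To make the traversal above concrete, here is a minimal usage sketch (illustrative only, not part of this commit). It builds a toy Parameter -> Convolution -> Result model with opset13 and asks the builder for the subgraph around the convolution; the friendly name "Conv" and the toy shapes are assumptions made for the example.

```python
import numpy as np
import openvino.runtime as ov
from openvino.runtime import opset13 as opset

from nncf.openvino.graph.model_builder import OVModelBuilder

# A toy model: Parameter -> Convolution (with constant weights) -> Result.
param = opset.parameter([1, 3, 4, 2], dtype=np.float32, name="Input")
weights = opset.constant(np.ones((3, 3, 1, 1), dtype=np.float32))
conv = opset.convolution(param, weights, strides=[1, 1], pads_begin=[0, 0], pads_end=[0, 0], dilations=[1, 1])
conv.set_friendly_name("Conv")
model = ov.Model([opset.result(conv)], [param])

# The mapping is built once per model and reused across calls (see the FastBC backend changes below).
node_mapping = {op.get_friendly_name(): op for op in model.get_ops()}

builder = OVModelBuilder()
# Cut out the subgraph that starts at input port 0 of "Conv" and ends at its output port 0.
# The weight constant is cloned, while the activation input becomes a new Parameter.
submodel = builder.build(
    input_ids=[("Conv", 0)],
    output_ids=[("Conv", 0)],
    node_mapping=node_mapping,
)
```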
10 changes: 7 additions & 3 deletions nncf/openvino/graph/node_utils.py
@@ -121,18 +121,22 @@ def get_const_value(const_node: ov.Node) -> np.ndarray:
     return const_node.data


-def get_bias_value(node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> np.ndarray:
+def get_bias_value(
+    node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model, node_mapping: Dict[str, ov.Node] = None
+) -> np.ndarray:
     """
     Returns the bias tensor for the biased node.

     :param node_with_bias: The node that corresponds to the operation with bias.
     :param nncf_graph: NNCFGraph instance.
     :param model: The model that contains this operation.
+    :param node_mapping: Original nodes mapping cache.
     :return: The bias value that is applied to the output tensor of the node's operation.
     """
-    ops_dict = {op.get_friendly_name(): op for op in model.get_ops()}
+    if node_mapping is None:
+        node_mapping = {op.get_friendly_name(): op for op in model.get_ops()}
     bias_constant = get_node_with_bias_value(get_add_bias_node(node_with_bias, nncf_graph), nncf_graph)
-    ov_bias_constant = ops_dict[bias_constant.node_name]
+    ov_bias_constant = node_mapping[bias_constant.node_name]
     return get_const_value(ov_bias_constant)

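The `node_mapping` argument above exists so that the name-to-op mapping can be built once and reused for every bias lookup instead of calling `model.get_ops()` each time; this is how the updated OpenVINO FastBC backend uses its cached `self._node_mapping`. A short sketch of the intent, where `model`, `nncf_graph`, and `biased_nodes` are placeholders rather than objects from this commit:

```python
from nncf.openvino.graph.node_utils import get_bias_value

# Build the friendly-name -> op mapping once per model; model.get_ops() can be costly on large graphs.
node_mapping = {op.get_friendly_name(): op for op in model.get_ops()}

for node_with_bias in biased_nodes:  # hypothetical iterable of NNCFNode objects
    bias = get_bias_value(node_with_bias, nncf_graph, model, node_mapping=node_mapping)
```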
22 changes: 2 additions & 20 deletions nncf/quantization/algorithms/fast_bias_correction/algorithm.py
@@ -17,7 +17,6 @@
 from nncf.common.factory import EngineFactory
 from nncf.common.factory import ModelTransformerFactory
 from nncf.common.graph.graph import NNCFGraph
-from nncf.common.graph.model_transformer import ModelTransformer
 from nncf.common.graph.transformations.commands import TargetPoint
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.graph.transformations.layout import TransformationLayout
@@ -111,7 +110,7 @@ def _set_backend_entity(self, model: TModel) -> None:
                 OVFastBiasCorrectionAlgoBackend,
             )

-            self._backend_entity = OVFastBiasCorrectionAlgoBackend()
+            self._backend_entity = OVFastBiasCorrectionAlgoBackend(model)
         elif model_backend == BackendType.TORCH:
             from nncf.quantization.algorithms.fast_bias_correction.torch_backend import PTFastBiasCorrectionAlgoBackend

@@ -167,7 +166,7 @@ def apply(
             # Outputs of the subgraphs for the FastBiasCorrection are the same across the backends.
             output_id = (out_node_name, 0)

-            extracted_model = self._extract_submodel(model_transformer, input_id, output_id)
+            extracted_model = self._backend_entity.extract_submodel(model_transformer, input_id, output_id)
             if extracted_model is None:
                 nncf_logger.debug(f"Skipping node {node_name} because cant extract submodel")
                 continue
@@ -287,23 +286,6 @@ def output_filter_func(point):
             output_fp.extend(tensor_collector.get_statistics().mean_values)
         return output_fp

-    def _extract_submodel(
-        self, model_transformer: ModelTransformer, input_id: Tuple[str, int], output_id: Tuple[str, int]
-    ) -> TModel:
-        """
-        Extracts sub-model using backend-specific ModelTransformer.
-
-        :param model_transformer: Backend-specific ModelTransformer.
-        :param input_id: Input ID.
-        :param output_id: Output ID.
-        :return: Backend-specific sub-model.
-        """
-        model_extraction_command = self._backend_entity.model_extraction_command([input_id], [output_id])
-        me_transformation_layout = TransformationLayout()
-        me_transformation_layout.register(model_extraction_command)
-        extracted_model = model_transformer.transform(me_transformation_layout)
-        return extracted_model
-
     def _add_statistic_point(self, container: StatisticPointsContainer, point: TargetPoint, axis: int) -> None:
         """
         Adds specific statistic point.
19 changes: 19 additions & 0 deletions nncf/quantization/algorithms/fast_bias_correction/backend.py
@@ -15,9 +15,11 @@

 from nncf.common.graph import NNCFGraph
 from nncf.common.graph import NNCFNode
+from nncf.common.graph.model_transformer import ModelTransformer
 from nncf.common.graph.transformations.commands import TargetPoint
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.graph.transformations.commands import TransformationCommand
+from nncf.common.graph.transformations.layout import TransformationLayout
 from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
 from nncf.tensor import Tensor

@@ -194,3 +196,20 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple
         :param input_shape: Shape of the input.
         :return: Channel axis number.
         """
+
+    def extract_submodel(
+        self, model_transformer: ModelTransformer, input_id: Tuple[str, int], output_id: Tuple[str, int]
+    ) -> TModel:
+        """
+        Extracts sub-model using backend-specific ModelTransformer.
+
+        :param model_transformer: Backend-specific ModelTransformer.
+        :param input_id: Input ID.
+        :param output_id: Output ID.
+        :return: Backend-specific sub-model.
+        """
+        model_extraction_command = self.model_extraction_command([input_id], [output_id])
+        me_transformation_layout = TransformationLayout()
+        me_transformation_layout.register(model_extraction_command)
+        extracted_model = model_transformer.transform(me_transformation_layout)
+        return extracted_model
nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py
@@ -20,6 +20,7 @@
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.openvino.graph.metatypes.groups import FAKE_QUANTIZE_OPERATIONS
 from nncf.openvino.graph.metatypes.groups import OPERATIONS_WITH_BIAS_REDUCED
+from nncf.openvino.graph.model_builder import OVModelBuilder
 from nncf.openvino.graph.node_utils import get_activation_channel_axis
 from nncf.openvino.graph.node_utils import get_bias_value
 from nncf.openvino.graph.node_utils import is_node_with_bias
@@ -33,6 +34,12 @@


 class OVFastBiasCorrectionAlgoBackend(FastBiasCorrectionAlgoBackend):
+
+    def __init__(self, model):
+        # Node mapping caching to reduce time for calculations
+        self._node_mapping = {op.get_friendly_name(): op for op in model.get_ops()}
+        self._model_builder = OVModelBuilder()
+
     @staticmethod
     def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> OVTargetPoint:
         return OVTargetPoint(target_type, target_node_name, port_id)
@@ -73,9 +80,8 @@ def create_input_data(
         input_data = {input_name: blob}
         return input_data

-    @staticmethod
-    def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> Tensor:
-        return Tensor(get_bias_value(node, nncf_graph, model))
+    def get_bias_value(self, node: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> Tensor:
+        return Tensor(get_bias_value(node, nncf_graph, model, node_mapping=self._node_mapping))

     @staticmethod
     def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]:
@@ -113,3 +119,11 @@ def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFG
     @staticmethod
     def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple[int]) -> int:
         return get_activation_channel_axis(node, port_id, input_shape)
+
+    def extract_submodel(self, model_transformer, input_id, output_id):
+
+        return self._model_builder.build(
+            input_ids=[input_id],
+            output_ids=[output_id],
+            node_mapping=self._node_mapping,
+        )
@@ -0,0 +1,9 @@
strict digraph {
"0 Parameter_Conv.0" [id=0, type=Parameter];
"1 Convolution_57" [id=1, type=Convolution];
"2 Result_Conv.0" [id=2, type=Result];
"3 Conv/Constant_4" [id=3, type=Constant];
"0 Parameter_Conv.0" -> "1 Convolution_57" [label="[1, 3, 4, 2]", style=solid];
"1 Convolution_57" -> "2 Result_Conv.0" [label="[1, 3, 4, 2]", style=solid];
"3 Conv/Constant_4" -> "1 Convolution_57" [label="[3, 3, 1, 1]", style=solid];
}