Skip to content

Commit

Permalink
Disable CA algo by default, enable all biases insertion
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jul 4, 2023
1 parent 434ef02 commit 09231a3
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 16 deletions.
4 changes: 3 additions & 1 deletion nncf/quantization/advanced_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ class AdvancedQuantizationParameters:
:param inplace_statistics: Defines whether to calculate quantizers statistics by
backend graph operations or by default Python implementation, defaults to True.
:type inplace_statistics: bool
:param disable_channel_alignment: Whether to disable the channel alignment, defaults to True.
:type disable_channel_alignment: bool
:param disable_bias_correction: Whether to disable the bias correction.
:type disable_bias_correction: bool
:param smooth_quant_alpha: SmoothQuant-related parameter. It regulates the calculation of the smooth scale.
Expand All @@ -150,7 +152,7 @@ class AdvancedQuantizationParameters:
overflow_fix: OverflowFix = OverflowFix.FIRST_LAYER
quantize_outputs: bool = False
inplace_statistics: bool = True
disable_channel_alignment: bool = False
disable_channel_alignment: bool = True
disable_bias_correction: bool = False
smooth_quant_alpha: float = 0.95

Expand Down
13 changes: 11 additions & 2 deletions nncf/quantization/algorithms/channel_alignment/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVSubtractMetatype
from nncf.openvino.graph.node_utils import get_bias_value
from nncf.openvino.graph.node_utils import get_node_with_bias_value
from nncf.openvino.graph.node_utils import get_weight_value
from nncf.openvino.graph.node_utils import is_node_with_bias
from nncf.openvino.graph.transformations.command_creation import OVCommandCreator
from nncf.openvino.graph.transformations.commands import OVTargetPoint
from nncf.openvino.graph.transformations.commands import OVWeightUpdateCommand
Expand Down Expand Up @@ -91,7 +91,16 @@ def get_statistic_collector(

@staticmethod
def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
    """
    Checks whether the given node is followed by an Add node that carries a bias constant.

    :param node: Node to check.
    :param nncf_graph: Graph that contains the node.
    :return: True if the first consumer of the node is an Add node for which
        `get_node_with_bias_value` finds a bias constant, False otherwise.
    """
    next_nodes = nncf_graph.get_next_nodes(node)
    if not next_nodes:
        # A node without consumers cannot be followed by a bias Add.
        return False

    # Only the first consumer is inspected as the bias candidate.
    # NOTE(review): assumes the bias Add, if present, is always the first consumer - confirm.
    add_node = next_nodes[0]
    if add_node.metatype != OVAddMetatype:
        return False

    bias_constant = get_node_with_bias_value(add_node, nncf_graph)
    return bias_constant is not None

@staticmethod
def create_bias_update_command(
Expand Down
22 changes: 17 additions & 5 deletions nncf/quantization/algorithms/post_training/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Dict, Optional, TypeVar
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, TypeVar

from nncf import Dataset
from nncf.common.logging import nncf_logger
Expand All @@ -30,6 +31,7 @@
from nncf.quantization.algorithms.fast_bias_correction.algorithm import FastBiasCorrection
from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
from nncf.quantization.algorithms.smooth_quant.algorithm import SmoothQuant
from nncf.quantization.passes import insert_null_biases_pass
from nncf.scopes import IgnoredScope

TModel = TypeVar("TModel")
Expand All @@ -44,6 +46,11 @@ class PostTrainingQuantization(Algorithm):
3) FastBiasCorrection or BiasCorrection
"""

@dataclass
class FirstStageAlgorithm:
    """
    Pairs a first-stage algorithm with the passes applied to the model before it.
    """

    # Algorithm instance to run.
    algorithm: "Algorithm"
    # Callables applied to the model, in order, before the algorithm.
    pre_passes: List[TPass]

def __init__(
self,
preset: QuantizationPreset = QuantizationPreset.PERFORMANCE,
Expand Down Expand Up @@ -77,7 +84,7 @@ def __init__(
"""
super().__init__()
self.algorithms = []
self.first_stage_algorithms = []
self.first_stage_algorithms: List[self.FirstStageAlgorithm] = []

if advanced_parameters is None:
advanced_parameters = AdvancedQuantizationParameters()
Expand All @@ -88,15 +95,15 @@ def __init__(
inplace_statistics=advanced_parameters.inplace_statistics,
alpha=advanced_parameters.smooth_quant_alpha,
)
self.first_stage_algorithms.append(smooth_quant_algorithm)
self.first_stage_algorithms.append(self.FirstStageAlgorithm(smooth_quant_algorithm, []))

if not advanced_parameters.disable_channel_alignment:
channel_alignment = ChannelAlignment(
subset_size=subset_size,
inplace_statistics=advanced_parameters.inplace_statistics,
backend_params=advanced_parameters.backend_params,
)
self.first_stage_algorithms.append(channel_alignment)
self.first_stage_algorithms.append(self.FirstStageAlgorithm(channel_alignment, [insert_null_biases_pass]))

min_max_quantization = MinMaxQuantization(
preset=preset,
Expand Down Expand Up @@ -192,7 +199,9 @@ def _apply(
backend = get_backend(modified_model)

if statistic_points is None:
for algorithm in self.first_stage_algorithms:
for first_stage_algorithm in self.first_stage_algorithms:
algorithm = first_stage_algorithm.algorithm

if isinstance(algorithm, SmoothQuant) and backend != BackendType.OPENVINO:
nncf_logger.debug(f"{backend.name} does not support SmoothQuant algorithm yet.")
continue
Expand All @@ -201,6 +210,9 @@ def _apply(
nncf_logger.debug(f"{backend.name} does not support ChannelAlignment algorithm yet.")
continue

for pre_pass in first_stage_algorithm.pre_passes:
modified_model = pre_pass(modified_model)

statistics_aggregator = self._create_statistics_aggregator(dataset, backend)
algo_statistic_points = algorithm.get_statistic_points(modified_model)
statistics_aggregator.register_statistic_points(algo_statistic_points)
Expand Down
22 changes: 21 additions & 1 deletion nncf/quantization/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
# limitations under the License.

import collections
from typing import List, Optional
from typing import List, Optional, TypeVar

from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend


def transform_to_inference_graph(
Expand Down Expand Up @@ -116,3 +118,21 @@ def filter_constant_nodes(
constant_nodes = [node for node in nncf_graph.get_all_nodes() if node not in visited_nodes]
nncf_graph.remove_nodes_from(constant_nodes)
return nncf_graph


TModel = TypeVar("TModel")


def insert_null_biases_pass(model: TModel) -> TModel:
    """
    Inserts zero biases into the layers of the given model that should have a bias.
    Models of backends other than OpenVINO are returned as is.

    :param model: Model instance.
    :return: Model instance with zero biases inserted.
    """
    if get_backend(model) != BackendType.OPENVINO:
        return model

    # Imported lazily, only once the model is known to be an OpenVINO one.
    from nncf.openvino.graph.model_utils import insert_null_biases

    return insert_null_biases(model)
31 changes: 24 additions & 7 deletions tests/openvino/native/quantization/test_sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from nncf.common.quantization.structs import QuantizationPreset
from nncf.common.utils.os import is_windows
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
from tests.openvino.conftest import AC_CONFIGS_DIR
from tests.openvino.datasets_helpers import get_dataset_for_test
from tests.openvino.datasets_helpers import get_nncf_dataset_from_ac_config
Expand All @@ -26,16 +27,28 @@
from tests.openvino.omz_helpers import download_model

# Each entry: (model name, dataset name, reference metrics, advanced quantization parameters or None).
# None is replaced with a default AdvancedQuantizationParameters() inside test_compression.
OMZ_MODELS = [
    (
        "resnet-18-pytorch",
        "imagenette2-320",
        {"accuracy@top1": 0.777, "accuracy@top5": 0.948},
        None,
    ),
    (
        "mobilenet-v3-small-1.0-224-tf",
        "imagenette2-320",
        {"accuracy@top1": 0.744, "accuracy@top5": 0.916},
        # Channel alignment is enabled explicitly for this model.
        AdvancedQuantizationParameters(disable_channel_alignment=False),
    ),
    ("googlenet-v3-pytorch", "imagenette2-320", {"accuracy@top1": 0.911, "accuracy@top5": 0.994}, None),
    ("mobilefacedet-v1-mxnet", "wider", {"map": 0.7763171885846742}, None),
    ("retinaface-resnet50-pytorch", "wider", {"map": 0.917961898320335}, None),
]


@pytest.mark.parametrize("model, dataset, ref_metrics", OMZ_MODELS, ids=[model[0] for model in OMZ_MODELS])
def test_compression(data_dir, tmp_path, model, dataset, ref_metrics):
@pytest.mark.parametrize(
"model, dataset, ref_metrics, advanced_params", OMZ_MODELS, ids=[model[0] for model in OMZ_MODELS]
)
def test_compression(data_dir, tmp_path, model, dataset, ref_metrics, advanced_params):
if is_windows() and model == "mobilefacedet-v1-mxnet":
pytest.xfail("OMZ for Windows has version 1.2.0 pinned that is incompatible with Python 3.8+")
extracted_data_dir = os.path.dirname(get_dataset_for_test(dataset, data_dir))
Expand All @@ -51,13 +64,17 @@ def test_compression(data_dir, tmp_path, model, dataset, ref_metrics):
calibration_dataset = get_nncf_dataset_from_ac_config(model_path, config_path, extracted_data_dir)

ov_model = ov.Core().read_model(str(model_path))

if advanced_params is None:
advanced_params = AdvancedQuantizationParameters()
quantized_model = nncf.quantize(
ov_model,
calibration_dataset,
QuantizationPreset.PERFORMANCE,
TargetDevice.ANY,
subset_size=300,
fast_bias_correction=True,
advanced_parameters=advanced_params,
)
ov.serialize(quantized_model, int8_ir_path)

Expand Down

0 comments on commit 09231a3

Please sign in to comment.