From 83b9a78c02df2c903e414d97e92c9a0cbcf3e405 Mon Sep 17 00:00:00 2001 From: Alexander Suslov Date: Thu, 3 Aug 2023 15:55:55 +0400 Subject: [PATCH] fixed rebase issues --- tests/onnx/quantization/test_bias_correction.py | 4 +++- tests/openvino/native/test_bias_correction.py | 4 +++- tests/post_training/test_templates/test_bias_correction.py | 3 ++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/onnx/quantization/test_bias_correction.py b/tests/onnx/quantization/test_bias_correction.py index a6eb61cd15c..2e09c072868 100644 --- a/tests/onnx/quantization/test_bias_correction.py +++ b/tests/onnx/quantization/test_bias_correction.py @@ -18,6 +18,7 @@ from nncf.common.factory import NNCFGraphFactory from nncf.onnx.graph.model_utils import remove_fq_from_inputs +from nncf.onnx.graph.nncf_graph_builder import GraphConverter from nncf.onnx.graph.node_utils import get_bias_value from nncf.quantization.algorithms.bias_correction.onnx_backend import ONNXBiasCorrectionAlgoBackend from tests.onnx.quantization.common import compare_nncf_graph @@ -62,7 +63,8 @@ def transform_fn(data_item): @staticmethod def remove_fq_from_inputs(model: onnx.ModelProto) -> onnx.ModelProto: - return remove_fq_from_inputs(model) + graph = GraphConverter.create_nncf_graph(model) + return remove_fq_from_inputs(model, graph) @staticmethod def get_ref_path(suffix: str) -> str: diff --git a/tests/openvino/native/test_bias_correction.py b/tests/openvino/native/test_bias_correction.py index 7339fd9e7ba..887a09040be 100644 --- a/tests/openvino/native/test_bias_correction.py +++ b/tests/openvino/native/test_bias_correction.py @@ -18,6 +18,7 @@ from nncf.common.factory import NNCFGraphFactory from nncf.openvino.graph.model_utils import remove_fq_from_inputs +from nncf.openvino.graph.nncf_graph_builder import GraphConverter from nncf.openvino.graph.node_utils import get_bias_value from nncf.quantization.algorithms.bias_correction.openvino_backend import OVBiasCorrectionAlgoBackend from tests.openvino.conftest import OPENVINO_NATIVE_TEST_ROOT @@ -65,7 +66,8 @@ def map_references(ref_biases: Dict) -> Dict[str, List]: @staticmethod def remove_fq_from_inputs(model: ov.Model) -> ov.Model: - return remove_fq_from_inputs(model) + graph = GraphConverter.create_nncf_graph(model) + return remove_fq_from_inputs(model, graph) @staticmethod def get_ref_path(suffix: str) -> str: diff --git a/tests/post_training/test_templates/test_bias_correction.py b/tests/post_training/test_templates/test_bias_correction.py index e7cc7e57ca2..68c72301707 100644 --- a/tests/post_training/test_templates/test_bias_correction.py +++ b/tests/post_training/test_templates/test_bias_correction.py @@ -134,7 +134,8 @@ def quantized_test_model(self, tmpdir) -> TModel: dataset = Dataset(self.get_dataset(model_cls.INPUT_SIZE), self.get_transform_fn()) quantization_algorithm = self.get_quantization_algorithm(disable_bias_correction=True) - quantized_model = quantization_algorithm.apply(model, dataset=dataset) + graph = NNCFGraphFactory.create(model) + quantized_model = quantization_algorithm.apply(model, graph, dataset=dataset) modified_model = self.remove_fq_from_inputs(quantized_model) return modified_model