Skip to content

Commit

Permalink
fixed rebase issues
Browse files Browse the repository at this point in the history
  • Loading branch information
alexsu52 committed Aug 3, 2023
1 parent 7bc0e3d commit 83b9a78
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
4 changes: 3 additions & 1 deletion tests/onnx/quantization/test_bias_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from nncf.common.factory import NNCFGraphFactory
from nncf.onnx.graph.model_utils import remove_fq_from_inputs
from nncf.onnx.graph.nncf_graph_builder import GraphConverter
from nncf.onnx.graph.node_utils import get_bias_value
from nncf.quantization.algorithms.bias_correction.onnx_backend import ONNXBiasCorrectionAlgoBackend
from tests.onnx.quantization.common import compare_nncf_graph
Expand Down Expand Up @@ -62,7 +63,8 @@ def transform_fn(data_item):

@staticmethod
def remove_fq_from_inputs(model: onnx.ModelProto) -> onnx.ModelProto:
return remove_fq_from_inputs(model)
graph = GraphConverter.create_nncf_graph(model)
return remove_fq_from_inputs(model, graph)

@staticmethod
def get_ref_path(suffix: str) -> str:
Expand Down
4 changes: 3 additions & 1 deletion tests/openvino/native/test_bias_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from nncf.common.factory import NNCFGraphFactory
from nncf.openvino.graph.model_utils import remove_fq_from_inputs
from nncf.openvino.graph.nncf_graph_builder import GraphConverter
from nncf.openvino.graph.node_utils import get_bias_value
from nncf.quantization.algorithms.bias_correction.openvino_backend import OVBiasCorrectionAlgoBackend
from tests.openvino.conftest import OPENVINO_NATIVE_TEST_ROOT
Expand Down Expand Up @@ -65,7 +66,8 @@ def map_references(ref_biases: Dict) -> Dict[str, List]:

@staticmethod
def remove_fq_from_inputs(model: ov.Model) -> ov.Model:
return remove_fq_from_inputs(model)
graph = GraphConverter.create_nncf_graph(model)
return remove_fq_from_inputs(model, graph)

@staticmethod
def get_ref_path(suffix: str) -> str:
Expand Down
3 changes: 2 additions & 1 deletion tests/post_training/test_templates/test_bias_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def quantized_test_model(self, tmpdir) -> TModel:
dataset = Dataset(self.get_dataset(model_cls.INPUT_SIZE), self.get_transform_fn())

quantization_algorithm = self.get_quantization_algorithm(disable_bias_correction=True)
quantized_model = quantization_algorithm.apply(model, dataset=dataset)
graph = NNCFGraphFactory.create(model)
quantized_model = quantization_algorithm.apply(model, graph, dataset=dataset)
modified_model = self.remove_fq_from_inputs(quantized_model)
return modified_model

Expand Down

0 comments on commit 83b9a78

Please sign in to comment.