diff --git a/tests/common/accuracy_control/backend.py b/tests/common/accuracy_control/backend.py new file mode 100644 index 00000000000..adccac2d0c4 --- /dev/null +++ b/tests/common/accuracy_control/backend.py @@ -0,0 +1,86 @@ +# Copyright (c) 2023 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List +from abc import abstractmethod +from typing import Any, List, Optional, TypeVar + + +import numpy as np +import pytest + +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.operator_metatypes import OperatorMetatype +from nncf.quantization.algorithms.accuracy_control.backend import AccuracyControlAlgoBackend +from nncf.quantization.algorithms.accuracy_control.backend import TModel +from nncf.quantization.algorithms.accuracy_control.rank_functions import normalized_mse +from nncf.quantization.algorithms.accuracy_control.ranker import Ranker +from nncf.quantization.algorithms.accuracy_control.subset_selection import get_subset_indices +from tests.common.quantization.metatypes import CONSTANT_METATYPES +from tests.common.quantization.metatypes import METATYPES_FOR_TEST +from tests.common.quantization.metatypes import QUANTIZABLE_METATYPES +from tests.common.quantization.metatypes import QUANTIZE_AGNOSTIC_METATYPES +from tests.common.quantization.metatypes import QUANTIZER_METATYPES +from tests.common.quantization.metatypes import ShapeOfTestMetatype +from tests.common.quantization.test_quantizer_removal import TestCase +from tests.common.quantization.test_quantizer_removal import create_test_params + +class AABackendForTests(AccuracyControlAlgoBackend): + @staticmethod + def get_quantizer_metatypes() -> List[OperatorMetatype]: + return QUANTIZER_METATYPES + + @staticmethod + def get_const_metatypes() -> List[OperatorMetatype]: + return CONSTANT_METATYPES + + @staticmethod + def get_quantizable_metatypes() -> List[OperatorMetatype]: + return QUANTIZABLE_METATYPES + + @staticmethod + def get_start_nodes_for_activation_path_tracing(nncf_graph: NNCFGraph) -> List[NNCFNode]: + return nncf_graph.get_input_nodes() + + @staticmethod + def get_quantize_agnostic_metatypes() -> List[OperatorMetatype]: + return QUANTIZE_AGNOSTIC_METATYPES + + @staticmethod + def get_shapeof_metatypes() -> List[OperatorMetatype]: + return [ShapeOfTestMetatype] + + @staticmethod + def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: + return False + + @staticmethod + def is_node_with_weight(node: NNCFNode) -> bool: + return False + + @staticmethod + def get_bias_value(node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: TModel) -> Any: + return None + @staticmethod + def get_weight_value(node_with_weight: NNCFNode, model: TModel, port_id: int) -> Any: + return None + @staticmethod + def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]: + return None + + @staticmethod + def get_model_size(model: TModel) -> int: + return 0 + + @staticmethod + def prepare_for_inference(model: TModel) -> TModel: + return model diff --git a/tests/common/accuracy_control/test_ranking.py b/tests/common/accuracy_control/test_ranking.py index a383f8281cc..6a153c9cf20 100644 --- a/tests/common/accuracy_control/test_ranking.py +++ b/tests/common/accuracy_control/test_ranking.py @@ -9,13 +9,20 @@ # See the License for the specific language governing permissions and # limitations under the License. + from typing import List import numpy as np import pytest +from nncf.common.graph.graph import NNCFGraph from nncf.quantization.algorithms.accuracy_control.rank_functions import normalized_mse +from nncf.quantization.algorithms.accuracy_control.ranker import GroupToRank +from nncf.quantization.algorithms.accuracy_control.ranker import Ranker from nncf.quantization.algorithms.accuracy_control.subset_selection import get_subset_indices +from tests.common.accuracy_control.backend import AABackendForTests +from tests.common.quantization.test_quantizer_removal import GRAPHS as AA_GRAPHS_DESCR +from tests.common.quantization.test_quantizer_removal import create_nncf_graph as aa_create_nncf_graph def create_fp32_tensor_1d(items): @@ -77,3 +84,42 @@ def test_normalized_mse(x_ref: np.ndarray, x_approx: np.ndarray, expected_nmse: def test_get_subset_indices(errors: List[float], subset_size: int, expected_indices: List[int]): actual_indices = get_subset_indices(errors, subset_size) assert expected_indices == actual_indices + + +@pytest.mark.parametrize( + "nncf_graph_name,ref_groups", + [ + ( + "simple_graph", + [ + GroupToRank(["quantizer_139", "quantizer_162", "quantizer_119"], ["add_117", "conv2d_161"]), + GroupToRank(["quantizer_153", "quantizer_147"], ["conv2d_146"]), + GroupToRank(["quantizer_134", "quantizer_128"], ["conv2d_127"]), + ], + ), + ( + "graph_with_shapeof", + [ + GroupToRank(["quantizer_105"], ["interpolate_115"]), + GroupToRank(["quantizer_710", "quantizer_93"], ["multiply_99"]), + GroupToRank(["quantizer_82"], ["power_87"]), + ], + ), + ], +) +def test_find_groups_of_quantizers_to_rank(nncf_graph_name: NNCFGraph, ref_groups: List[GroupToRank]): + ranker = Ranker(1, tuple(), AABackendForTests, None) + nncf_graph = aa_create_nncf_graph(AA_GRAPHS_DESCR[nncf_graph_name]) + ret_val = ranker.find_groups_of_quantizers_to_rank(nncf_graph) + # Check ret_val + assert len(ret_val) == len(ref_groups) + # Can zip as qauantizers are topologically sorted + for actual_group, ref_group in zip(ret_val, ref_groups): + for attr in ["quantizers", "operations"]: + acutal_attr_value = getattr(actual_group, attr) + ref_attr_value = getattr(ref_group, attr) + + assert len(acutal_attr_value) == len(ref_attr_value) + actual_node_names = [n.node_name for n in acutal_attr_value] + for ref_node_name in ref_attr_value: + assert ref_node_name in actual_node_names