Skip to content

Commit

Permalink
test_find_groups_of_quantizers_to_rank is presented
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Dec 6, 2023
1 parent 21dbece commit 24cbc33
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 0 deletions.
86 changes: 86 additions & 0 deletions tests/common/accuracy_control/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from abc import abstractmethod
from typing import Any, List, Optional, TypeVar


import numpy as np
import pytest

from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.quantization.algorithms.accuracy_control.backend import AccuracyControlAlgoBackend
from nncf.quantization.algorithms.accuracy_control.backend import TModel
from nncf.quantization.algorithms.accuracy_control.rank_functions import normalized_mse
from nncf.quantization.algorithms.accuracy_control.ranker import Ranker
from nncf.quantization.algorithms.accuracy_control.subset_selection import get_subset_indices
from tests.common.quantization.metatypes import CONSTANT_METATYPES
from tests.common.quantization.metatypes import METATYPES_FOR_TEST
from tests.common.quantization.metatypes import QUANTIZABLE_METATYPES
from tests.common.quantization.metatypes import QUANTIZE_AGNOSTIC_METATYPES
from tests.common.quantization.metatypes import QUANTIZER_METATYPES
from tests.common.quantization.metatypes import ShapeOfTestMetatype
from tests.common.quantization.test_quantizer_removal import TestCase
from tests.common.quantization.test_quantizer_removal import create_test_params

class AABackendForTests(AccuracyControlAlgoBackend):
@staticmethod
def get_quantizer_metatypes() -> List[OperatorMetatype]:
return QUANTIZER_METATYPES

@staticmethod
def get_const_metatypes() -> List[OperatorMetatype]:
return CONSTANT_METATYPES

@staticmethod
def get_quantizable_metatypes() -> List[OperatorMetatype]:
return QUANTIZABLE_METATYPES

@staticmethod
def get_start_nodes_for_activation_path_tracing(nncf_graph: NNCFGraph) -> List[NNCFNode]:
return nncf_graph.get_input_nodes()

@staticmethod
def get_quantize_agnostic_metatypes() -> List[OperatorMetatype]:
return QUANTIZE_AGNOSTIC_METATYPES

@staticmethod
def get_shapeof_metatypes() -> List[OperatorMetatype]:
return [ShapeOfTestMetatype]

@staticmethod
def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
return False

@staticmethod
def is_node_with_weight(node: NNCFNode) -> bool:
return False

@staticmethod
def get_bias_value(node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: TModel) -> Any:
return None
@staticmethod
def get_weight_value(node_with_weight: NNCFNode, model: TModel, port_id: int) -> Any:
return None
@staticmethod
def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]:
return None

@staticmethod
def get_model_size(model: TModel) -> int:
return 0

@staticmethod
def prepare_for_inference(model: TModel) -> TModel:
return model
46 changes: 46 additions & 0 deletions tests/common/accuracy_control/test_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List

import numpy as np
import pytest

from nncf.common.graph.graph import NNCFGraph
from nncf.quantization.algorithms.accuracy_control.rank_functions import normalized_mse
from nncf.quantization.algorithms.accuracy_control.ranker import GroupToRank
from nncf.quantization.algorithms.accuracy_control.ranker import Ranker
from nncf.quantization.algorithms.accuracy_control.subset_selection import get_subset_indices
from tests.common.accuracy_control.backend import AABackendForTests
from tests.common.quantization.test_quantizer_removal import GRAPHS as AA_GRAPHS_DESCR
from tests.common.quantization.test_quantizer_removal import create_nncf_graph as aa_create_nncf_graph


def create_fp32_tensor_1d(items):
Expand Down Expand Up @@ -77,3 +84,42 @@ def test_normalized_mse(x_ref: np.ndarray, x_approx: np.ndarray, expected_nmse:
def test_get_subset_indices(errors: List[float], subset_size: int, expected_indices: List[int]):
actual_indices = get_subset_indices(errors, subset_size)
assert expected_indices == actual_indices


@pytest.mark.parametrize(
"nncf_graph_name,ref_groups",
[
(
"simple_graph",
[
GroupToRank(["quantizer_139", "quantizer_162", "quantizer_119"], ["add_117", "conv2d_161"]),
GroupToRank(["quantizer_153", "quantizer_147"], ["conv2d_146"]),
GroupToRank(["quantizer_134", "quantizer_128"], ["conv2d_127"]),
],
),
(
"graph_with_shapeof",
[
GroupToRank(["quantizer_105"], ["interpolate_115"]),
GroupToRank(["quantizer_710", "quantizer_93"], ["multiply_99"]),
GroupToRank(["quantizer_82"], ["power_87"]),
],
),
],
)
def test_find_groups_of_quantizers_to_rank(nncf_graph_name: NNCFGraph, ref_groups: List[GroupToRank]):
ranker = Ranker(1, tuple(), AABackendForTests, None)
nncf_graph = aa_create_nncf_graph(AA_GRAPHS_DESCR[nncf_graph_name])
ret_val = ranker.find_groups_of_quantizers_to_rank(nncf_graph)
# Check ret_val
assert len(ret_val) == len(ref_groups)
# Can zip as qauantizers are topologically sorted
for actual_group, ref_group in zip(ret_val, ref_groups):
for attr in ["quantizers", "operations"]:
acutal_attr_value = getattr(actual_group, attr)
ref_attr_value = getattr(ref_group, attr)

assert len(acutal_attr_value) == len(ref_attr_value)
actual_node_names = [n.node_name for n in acutal_attr_value]
for ref_node_name in ref_attr_value:
assert ref_node_name in actual_node_names

0 comments on commit 24cbc33

Please sign in to comment.