diff --git a/nncf/common/quantization/initialization/range.py b/nncf/common/quantization/initialization/range.py index b4a98565ceb..c21ec82b5c5 100644 --- a/nncf/common/quantization/initialization/range.py +++ b/nncf/common/quantization/initialization/range.py @@ -8,8 +8,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from nncf.common.graph.utils import get_reduction_axes from nncf.common.initialization.dataloader import NNCFDataLoader @@ -26,7 +27,12 @@ class RangeInitConfig: parameters. """ - def __init__(self, init_type: str, num_init_samples: int, init_type_specific_params: Dict = None): + def __init__( + self, + init_type: str, + num_init_samples: int, + init_type_specific_params: Optional[Dict[str, int]] = None, + ): """ Initializes the quantization range initialization parameters. 
@@ -43,11 +49,11 @@ def __init__(self, init_type: str, num_init_samples: int, init_type_specific_par if self.init_type_specific_params is None: self.init_type_specific_params = {} - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return self.__dict__ == other.__dict__ @classmethod - def from_dict(cls, dct: Dict) -> "RangeInitConfig": + def from_dict(cls, dct: Dict[str, Any]) -> RangeInitConfig: num_init_samples = dct.get("num_init_samples", NUM_INIT_SAMPLES) if num_init_samples < 0: raise ValueError("Number of initialization samples must be >= 0") @@ -94,10 +100,10 @@ def __init__( self.target_group = target_quantizer_group @classmethod - def from_dict(cls, dct: Dict) -> "PerLayerRangeInitConfig": + def from_dict(cls, dct: Dict[str, Any]) -> PerLayerRangeInitConfig: base_config = RangeInitConfig.from_dict(dct) - def get_list(dct: Dict, attr_name: str) -> Optional[List[str]]: + def get_list(dct: Dict[str, Any], attr_name: str) -> Optional[List[str]]: str_or_list = dct.get(attr_name) if str_or_list is None: return None @@ -185,7 +191,7 @@ def is_per_channel(self) -> bool: """ return self._is_per_channel - def use_per_sample_stats(self, per_sample_stats) -> bool: + def use_per_sample_stats(self, per_sample_stats: bool) -> bool: """ For activations, if per_sample_stats is True, statistics will be collected per-sample. For weights statistics are always collected per-batch. @@ -213,7 +219,7 @@ def _get_reduction_axes( shape_to_reduce: Union[Tuple[int, ...], List[int]], quantization_axes: Union[Tuple[int, ...], List[int]], aggregation_axes: Union[Tuple[int, ...], List[int]], - ): + ) -> Tuple[int, ...]: """ Returns axes for a reducer regarding aggregation axes. As aggregator takes axes counting from stacked tensors, from these axes only tensor related axes should be used for reducer. 
@@ -225,7 +231,7 @@ def _get_reduction_axes( """ axes_to_keep = set(el - 1 for el in aggregation_axes if el != 0) axes_to_keep.update(quantization_axes) - return get_reduction_axes(axes_to_keep, shape_to_reduce) + return get_reduction_axes(list(axes_to_keep), shape_to_reduce) def _get_aggregation_axes(self, batchwise_statistics: bool) -> Tuple[int, ...]: """ diff --git a/nncf/common/quantization/quantizer_propagation/graph.py b/nncf/common/quantization/quantizer_propagation/graph.py index e23e2a21bf5..15c6d4e7f3f 100644 --- a/nncf/common/quantization/quantizer_propagation/graph.py +++ b/nncf/common/quantization/quantizer_propagation/graph.py @@ -13,9 +13,9 @@ from copy import copy from copy import deepcopy from dataclasses import dataclass -from typing import Any, Callable, Deque, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Deque, Dict, List, Optional, Set, Tuple, Type, Union, cast -import networkx as nx +import networkx as nx # type: ignore[import-untyped] import nncf from nncf import nncf_logger @@ -49,7 +49,7 @@ from nncf.common.scopes import should_consider_scope -class QuantizerPropagationStateGraph(nx.DiGraph): +class QuantizerPropagationStateGraph(nx.DiGraph): # type: ignore[misc] """ This class is based upon InsertionPointGraph and represents a"chessboard" for PropagatingQuantizer items. 
It tracks the current state of @@ -77,14 +77,16 @@ class QuantizerPropagationStateGraph(nx.DiGraph): def __init__( self, ip_graph: InsertionPointGraph, - ignored_scopes: Dict[str, IgnoreReason] = None, + ignored_scopes: Optional[Dict[str, IgnoreReason]] = None, target_scopes: List[str] = None, ): super().__init__() ip_graph = deepcopy(ip_graph) self._created_prop_quantizer_counter = 0 + if ignored_scopes is None: + ignored_scopes = {} - self._ignored_scopes = list(ignored_scopes.keys()) if ignored_scopes is not None else None + self._ignored_scopes = list(ignored_scopes.keys()) self._target_scopes = deepcopy(target_scopes) self.ignored_node_keys: Dict[str, IgnoreReason] = {} @@ -96,7 +98,7 @@ def __init__( iteration_scope_node_keys = [] for node_key, node in ip_graph.nodes.items(): - qpg_node = { + qpg_node: Dict[str, Any] = { self.NODE_TYPE_NODE_ATTR: self.ipg_node_type_to_qpsg_node_type( node[InsertionPointGraph.NODE_TYPE_NODE_ATTR] ) @@ -164,7 +166,7 @@ def __init__( for barred_node_key in list(self.ignored_node_keys.keys()) + iteration_scope_node_keys: self._add_barrier_after_node(barred_node_key) - self._branch_nodes_directly_dominating_outputs = None + self._branch_nodes_directly_dominating_outputs: Optional[Set[str]] = None def get_input_node_keys(self) -> List[str]: """ @@ -172,7 +174,7 @@ def get_input_node_keys(self) -> List[str]: :return: List of the input node keys. 
""" - return self._input_node_keys_vs_nncf_nodes.keys() + return list(self._input_node_keys_vs_nncf_nodes.keys()) def get_node_keys_by_metatype(self, metatype: Type[OperatorMetatype]) -> List[str]: """ @@ -196,7 +198,7 @@ def _insertion_point_to_quant_insertion_point( assert isinstance(ip, PostHookInsertionPoint) return ActivationQuantizationInsertionPoint(ip.target_node_name, input_port_id=None) - def _add_barrier_after_node(self, node_key: str): + def _add_barrier_after_node(self, node_key: str) -> None: qpg_node_barrier = { self.NODE_TYPE_NODE_ATTR: QuantizerPropagationStateGraphNodeType.AUXILIARY_BARRIER, "label": QuantizerPropagationStateGraph.BARRIER_NODE_KEY_POSTFIX, @@ -227,7 +229,7 @@ def ipg_node_type_to_qpsg_node_type( def get_barrier_node_key(node_key: str) -> str: return f"{QuantizerPropagationStateGraph.BARRIER_NODE_KEY_POSTFIX} {node_key}" - def mark_act_quantizer_as_dependent_on_weights(self, pq: PropagatingQuantizer, operator_node_key: str): + def mark_act_quantizer_as_dependent_on_weights(self, pq: PropagatingQuantizer, operator_node_key: str) -> None: """ Marks a given propagating quantizer corresponding to input activation quantization of some downstream op as dependent on weights of an operation that gives its weights directly @@ -266,13 +268,15 @@ def is_insertion_point(qpsg_node_type: QuantizerPropagationStateGraphNodeType) - QuantizerPropagationStateGraphNodeType.POST_HOOK, ] - def merge_quantizer_into_path(self, prop_quantizer: PropagatingQuantizer, path: PropagationPath): + def merge_quantizer_into_path(self, prop_quantizer: PropagatingQuantizer, path: PropagationPath) -> None: curr_node = self.nodes[prop_quantizer.current_location_node_key] curr_node[QuantizerPropagationStateGraph.PROPAGATING_QUANTIZER_NODE_ATTR] = None surviving_quantizers: List[PropagatingQuantizer] = [] for from_node_key, to_node_key in path: edge = self.edges[from_node_key, to_node_key] - edge_affecting_quantizers = 
edge[QuantizerPropagationStateGraph.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] + edge_affecting_quantizers = cast( + List[PropagatingQuantizer], edge[QuantizerPropagationStateGraph.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] + ) if edge_affecting_quantizers: surviving_quantizers = copy(edge_affecting_quantizers) break @@ -282,7 +286,9 @@ def merge_quantizer_into_path(self, prop_quantizer: PropagatingQuantizer, path: from_node = self.nodes[from_node_key] from_node_type = from_node[QuantizerPropagationStateGraph.NODE_TYPE_NODE_ATTR] if self.is_insertion_point(from_node_type): - node_propagating_quantizer = from_node[QuantizerPropagationStateGraph.PROPAGATING_QUANTIZER_NODE_ATTR] + node_propagating_quantizer = cast( + PropagatingQuantizer, from_node[QuantizerPropagationStateGraph.PROPAGATING_QUANTIZER_NODE_ATTR] + ) if node_propagating_quantizer is not None: surviving_quantizers = [node_propagating_quantizer] break @@ -313,12 +319,14 @@ def merge_quantizer_into_path(self, prop_quantizer: PropagatingQuantizer, path: if prop_quantizer.unified_scale_type is not None: gid = self._unified_scale_group_manager.get_group_id_by_propagating_quantizer_id(prop_quantizer.id) + assert gid is not None for other_pq in surviving_quantizers: if other_pq.unified_scale_type is not None: other_gid = self._unified_scale_group_manager.get_group_id_by_propagating_quantizer_id( other_pq.id ) - self._unified_scale_group_manager.merge_groups(gid, other_gid) + if other_gid is not None: + self._unified_scale_group_manager.merge_groups(gid, other_gid) else: self._unified_scale_group_manager.add_to_group(gid, other_pq) @@ -351,8 +359,8 @@ def _get_major_unified_scale_type(type_list: List[Optional[UnifiedScaleType]]) - def merge_quantizers_for_branching_node( self, quantizers_to_merge: List[PropagatingQuantizer], - merged_qconf_list: List[QuantizerConfig], - branch_qconf_lists: List[Optional[List[QuantizerConfig]]], + merged_qconf_list: Optional[List[QuantizerConfig]], + branch_qconf_lists: 
List[Optional[List[QuantizerConfig]]], branching_node_key: str, ) -> List[PropagatingQuantizer]: # A branching node may currently be either a post-hook node, or an operator node if the @@ -363,10 +371,10 @@ def merge_quantizers_for_branching_node( if self.is_insertion_point(branching_node_type): target_ip_node_keys.append(branching_node_key) elif branching_node_type == QuantizerPropagationStateGraphNodeType.OPERATOR: - paths = self.get_paths_to_immediately_dominating_insertion_points(branching_node_key) - for path in paths: - assert len(path) == 1 - edge_from_pre_hook_ip_to_op = path[0] + prop_paths = self.get_paths_to_immediately_dominating_insertion_points(branching_node_key) + for prop_path in prop_paths: + assert len(prop_path) == 1 + edge_from_pre_hook_ip_to_op = prop_path[0] pre_hook_ip = edge_from_pre_hook_ip_to_op[0] target_ip_node_keys.append(pre_hook_ip) else: @@ -424,11 +432,13 @@ def merge_quantizers_for_branching_node( merge_pqs.append(merge_pq) - unified_scale_gids_to_merge = set() + unified_scale_gids_to_merge: Set[int] = set() for idx, pq in enumerate(quantizers_to_merge): branch_qconf_list = branch_qconf_lists[idx] if branch_qconf_list is None and pq.unified_scale_type is not None: gid = self._unified_scale_group_manager.get_group_id_by_propagating_quantizer_id(pq.id) + if gid is None: + raise nncf.InternalError("gid is None") unified_scale_gids_to_merge.add(gid) if unified_scale_gids_to_merge: @@ -439,11 +449,13 @@ def merge_quantizers_for_branching_node( for idx, pq in enumerate(quantizers_to_merge): branch_qconf_list = branch_qconf_lists[idx] if branch_qconf_list is None: - paths = list(nx.all_shortest_paths(self, branching_node_key, pq.current_location_node_key)) + paths: List[List[str]] = list( + nx.all_shortest_paths(self, branching_node_key, pq.current_location_node_key) + ) assert len(paths) == 1, "Ambiguous merge path!" 
# merge_quantizer_into_path expects paths as lists of edges path = paths[0] - edge_path = [] + edge_path: List[Tuple[str, str]] = [] for i in range(len(path) - 1): from_node_key = path[i] to_node_key = path[i + 1] @@ -517,7 +529,7 @@ def backtrack_propagation_until_accepting_location( from_node_key, to_node_key = prop_quantizer.propagation_path.pop() edge = self.edges[from_node_key, to_node_key] - edge[QuantizerPropagationStateGraph.AFFECTING_PROPAGATING_QUANTIZERS_ATTR].remove(prop_quantizer) + edge[QuantizerPropagationStateGraph.AFFECTING_PROPAGATING_QUANTIZERS_ATTR].remove(prop_quantizer) # type: ignore prop_quantizer.affected_edges.remove((from_node_key, to_node_key)) from_node = self.nodes[from_node_key] from_node_type = from_node[QuantizerPropagationStateGraph.NODE_TYPE_NODE_ATTR] @@ -540,7 +552,7 @@ def unify_pq_scales( primary_pq: PropagatingQuantizer, secondary_pq: PropagatingQuantizer, unified_scale_type: Optional[UnifiedScaleType] = None, - ): + ) -> None: if unified_scale_type is None: primary_pq.unified_scale_type = UnifiedScaleType.UNIFY_ALWAYS else: @@ -591,7 +603,7 @@ def add_propagating_quantizer( prop_quantizer.quantized_input_sink_operator_nodes.add(affected_op_node_key) return prop_quantizer - def _verify_nodes_and_edges_for_pq(self, prop_quantizer: PropagatingQuantizer): + def _verify_nodes_and_edges_for_pq(self, prop_quantizer: PropagatingQuantizer) -> None: node_keys_to_verify = ( list(prop_quantizer.affected_operator_nodes) + list(prop_quantizer.quantized_input_sink_operator_nodes) @@ -616,7 +628,7 @@ def _verify_nodes_and_edges_for_pq(self, prop_quantizer: PropagatingQuantizer): @staticmethod def _verify_qconfig_matching( prop_quantizer: PropagatingQuantizer, existing_prop_quantizers: List[PropagatingQuantizer] - ): + ) -> None: for existing_pq in existing_prop_quantizers: if existing_pq.potential_quant_configs != prop_quantizer.potential_quant_configs: raise nncf.InternalError( @@ -624,7 +636,7 @@ def _verify_qconfig_matching( "existing 
quantizer {}".format(existing_pq.id) ) - def register_propagating_quantizer(self, prop_quantizer: PropagatingQuantizer): + def register_propagating_quantizer(self, prop_quantizer: PropagatingQuantizer) -> None: """Will only succeed if the new quantizer information is consistent with the rest of the graph state.""" all_pqs = self.collect_all_propagating_quantizers() for existing_pq_id in all_pqs: @@ -688,13 +700,14 @@ def clone_propagating_quantizer(self, prop_quantizer: PropagatingQuantizer) -> P if cloned_prop_quant.unified_scale_type is not None: gid = self._unified_scale_group_manager.get_group_id_by_propagating_quantizer_id(prop_quantizer.id) + assert gid is not None self._unified_scale_group_manager.add_to_group(gid, cloned_prop_quant) return cloned_prop_quant def remove_propagating_quantizer( - self, prop_quantizer: PropagatingQuantizer, keep_propagating_quantizer_at_current_node=False - ): + self, prop_quantizer: PropagatingQuantizer, keep_propagating_quantizer_at_current_node: bool = False + ) -> None: for edge_tuple in prop_quantizer.affected_edges: edge = self.edges[edge_tuple] affecting_quantizers = edge[QuantizerPropagationStateGraph.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] @@ -718,6 +731,7 @@ def remove_propagating_quantizer( prop_quantizer.affected_edges.clear() if prop_quantizer.unified_scale_type is not None: gid = self._unified_scale_group_manager.get_group_id_by_propagating_quantizer_id(prop_quantizer.id) + assert gid is not None self._unified_scale_group_manager.remove_from_group(gid, prop_quantizer) self._pqs_after_weight_dependent_output_quantized_nodes.pop(prop_quantizer, None) @@ -752,11 +766,11 @@ def propagate_quantizer_via_path( target_node[QuantizerPropagationStateGraph.PROPAGATING_QUANTIZER_NODE_ATTR] = prop_quantizer return prop_quantizer - def get_non_quant_agnostic_op_nodes_immediately_dominated_by_node(self, node_key) -> List[str]: - ret_node_key_list = [] + def get_non_quant_agnostic_op_nodes_immediately_dominated_by_node(self, 
node_key: str) -> List[str]: + ret_node_key_list: List[str] = [] - def recursive_helper(curr_node_key: str, target_node_list: List[str]): - successors = self.successors(curr_node_key) + def recursive_helper(curr_node_key: str, target_node_list: List[str]) -> None: + successors = cast(List[str], self.successors(curr_node_key)) for successor_key in successors: successor = self.nodes[successor_key] successor_node_type = successor[QuantizerPropagationStateGraph.NODE_TYPE_NODE_ATTR] @@ -770,14 +784,14 @@ def recursive_helper(curr_node_key: str, target_node_list: List[str]): recursive_helper(node_key, ret_node_key_list) return ret_node_key_list - def all_outputs_are_quantized(self, node_key) -> bool: + def all_outputs_are_quantized(self, node_key: str) -> bool: """ - Returns True if all pathes from the given node to the first - input quantable nodes have an activation quantizer, False otherwise. + Returns True if all paths from the given node to the first + input quantizable nodes have an activation quantizer, False otherwise. :param node_key: Given node key. - :return: True if all pathes from the given node to the first - input quantable nodes have an activation quantizer, False otherwise. + :return: True if all paths from the given node to the first + input quantizable nodes have an activation quantizer, False otherwise. 
""" nodes_keys_stack = deque(self.successors(node_key)) @@ -811,13 +825,13 @@ def get_paths_to_immediately_dominating_insertion_points_grouped_by_unified_scal self, insertion_point_node_key: str, unified_scale_op_metatypes: Set[Type[OperatorMetatype]], - scales_unification_map: Dict[OperatorMetatype, OperatorMetatype], + scales_unification_map: Dict[Type[OperatorMetatype], List[Type[OperatorMetatype]]], ) -> Dict[Optional[int], List[PropagationPath]]: """Paths are lists of edges.""" next_group_idx = 0 - paths = {} + paths: Dict[Union[int, None], List[List[Tuple[str, str]]]] = {} - def followed_by_weighted_types(curr_node_key, curr_node_metatype) -> bool: + def followed_by_weighted_types(curr_node_key: str, curr_node_metatype: type[OperatorMetatype]) -> bool: nodes_queue = deque(self.successors(curr_node_key)) while nodes_queue: next_node_key = nodes_queue.popleft() @@ -837,7 +851,12 @@ def followed_by_weighted_types(curr_node_key, curr_node_metatype) -> bool: return True return False - def recursive_helper(curr_edge, curr_path, all_paths, curr_group): + def recursive_helper( + curr_edge: Tuple[str, str], + curr_path: List[Tuple[str, str]], + all_paths: Dict[Union[int, None], List[List[Tuple[str, str]]]], + curr_group: Optional[int], + ) -> None: nonlocal next_group_idx curr_path.append(curr_edge) curr_node_key = curr_edge[0] @@ -927,8 +946,8 @@ def traverse_fn(curr_node_key: str, local_state: LocalState) -> Tuple[bool, Loca local_state.encountered_quantizer_aware_ops = True return False, local_state - visited_node_keys = set() - result = set() + visited_node_keys: set[str] = set() + result: Set[str] = set() for output_node_key in self._output_node_keys_vs_nncf_nodes: output_state = LocalState(result) self._traverse_graph_recursive_helper( @@ -945,7 +964,7 @@ def is_branching_node_dominating_outputs(self, from_node_key: str) -> bool: self._branch_nodes_directly_dominating_outputs = self._build_branch_direct_output_dominators_info() return from_node_key in 
self._branch_nodes_directly_dominating_outputs - def get_visualized_graph(self): + def get_visualized_graph(self) -> nx.DiGraph: out_graph = nx.DiGraph() unified_scale_group_vs_pq_node_id_dict: Dict[int, List[str]] = {} for node_key, node in self.nodes.items(): @@ -981,6 +1000,7 @@ def get_visualized_graph(self): gid = self._unified_scale_group_manager.get_group_id_by_propagating_quantizer_id( prop_quantizer.id ) + assert gid is not None if gid in unified_scale_group_vs_pq_node_id_dict: unified_scale_group_vs_pq_node_id_dict[gid].append(quant_node_key) else: @@ -1067,7 +1087,7 @@ def _traverse_graph_recursive_helper( output: Any, traverse_backward: bool = False, visit_once: bool = True, - ): + ) -> Any: """This is DFS, and may fail with 'maximum recursion depth exceeded' for complex graphs.""" is_finished, output = traverse_function(curr_node_key, output) if visit_once: @@ -1082,26 +1102,26 @@ def _traverse_graph_recursive_helper( ) return output - def _get_next_prop_quantizer_id(self): + def _get_next_prop_quantizer_id(self) -> int: self._created_prop_quantizer_counter += 1 return self._created_prop_quantizer_counter - def _is_position_accepting(self, ip_node_key: str): + def _is_position_accepting(self, ip_node_key: str) -> bool: return True - def get_unified_scale_group_id_by_propagating_quantizer_id(self, pqid: int) -> int: + def get_unified_scale_group_id_by_propagating_quantizer_id(self, pqid: int) -> Optional[int]: return self._unified_scale_group_manager.get_group_id_by_propagating_quantizer_id(pqid) def get_quantizers_at_input_nncf_nodes(self) -> Dict[NNCFNode, List[int]]: retval: Dict[NNCFNode, List[int]] = {} - def recursive_helper(curr_node_key: str, curr_input_quantizer_ids_list: List[int]): + def recursive_helper(curr_node_key: str, curr_input_quantizer_ids_list: List[int]) -> None: curr_node = self.nodes[curr_node_key] curr_node_type = curr_node[QuantizerPropagationStateGraph.NODE_TYPE_NODE_ATTR] if self.is_insertion_point(curr_node_type): pq = 
curr_node[QuantizerPropagationStateGraph.PROPAGATING_QUANTIZER_NODE_ATTR] - if pq is not None: + if isinstance(pq, PropagatingQuantizer): curr_input_quantizer_ids_list.append(pq.id) return elif curr_node_type == QuantizerPropagationStateGraphNodeType.OPERATOR: @@ -1115,16 +1135,16 @@ def recursive_helper(curr_node_key: str, curr_input_quantizer_ids_list: List[int recursive_helper(successor_key, curr_input_quantizer_ids_list) for input_node_key, input_nncf_node in self._input_node_keys_vs_nncf_nodes.items(): - current_input_quantizer_ids = [] + current_input_quantizer_ids: List[int] = [] recursive_helper(input_node_key, current_input_quantizer_ids) retval[input_nncf_node] = current_input_quantizer_ids return retval - def merge_redundant_subsequent_quantizers_across_graph(self): + def merge_redundant_subsequent_quantizers_across_graph(self) -> None: def is_downstream_quantizer_redundant( downstream_quantizer: PropagatingQuantizer, upstream_quantizer: PropagatingQuantizer - ): + ) -> bool: ds_configs = downstream_quantizer.potential_quant_configs us_configs = upstream_quantizer.potential_quant_configs assert len(ds_configs) == 1 @@ -1149,7 +1169,7 @@ def is_downstream_quantizer_redundant( def merge_traverse_fn( curr_node_key: str, affecting_pq_and_prev_node_key: Tuple[Optional[PropagatingQuantizer], str] - ) -> Tuple[Optional[PropagatingQuantizer], str]: + ) -> Tuple[bool, Tuple[Optional[PropagatingQuantizer], str]]: # For this to work, DFS must be used for graph traversal. Also, this only # works with the generic traverse_graph interface because of # Python's pass-by-value mechanism for tuples. 
@@ -1221,7 +1241,7 @@ def get_quant_insertion_point_for_propagating_quantizer( final_node_key = prop_quant.current_location_node_key final_node = self.nodes[final_node_key] insertion_point = final_node[QuantizerPropagationStateGraph.QUANT_INSERTION_POINT_DATA_NODE_ATTR] - return insertion_point + return cast(QuantizationInsertionPointBase, insertion_point) def _get_all_quantizers_grouped_by_affecting_op_set(self) -> List[SharedAffectedOpsPropagatingQuantizerGroup]: all_pqs = self.collect_all_propagating_quantizers() @@ -1234,20 +1254,20 @@ class Grouper: scenario) will be placed in a separate group. """ - def __init__(self): + def __init__(self) -> None: self._group_vs_node_keys_and_pqs: Dict[int, SharedAffectedOpsPropagatingQuantizerGroup] = {} self._next_gid = 0 - def _get_next_gid(self): + def _get_next_gid(self) -> int: curr_gid = self._next_gid self._next_gid += 1 return curr_gid - def _merge_groups(self, gid_to: int, gid_from: int): + def _merge_groups(self, gid_to: int, gid_from: int) -> None: self._group_vs_node_keys_and_pqs[gid_to].update(self._group_vs_node_keys_and_pqs[gid_from]) self._group_vs_node_keys_and_pqs.pop(gid_from) - def add_pq(self, pq: PropagatingQuantizer): + def add_pq(self, pq: PropagatingQuantizer) -> None: new_gid = self._get_next_gid() self._group_vs_node_keys_and_pqs[new_gid] = SharedAffectedOpsPropagatingQuantizerGroup( {pq}, set(pq.quantized_input_sink_operator_nodes) @@ -1480,10 +1500,10 @@ def _get_weight_and_activation_qconfig_list_intersection( act_qconfig_extend_list += activation_qconfig_options return [qconf for qconf in weight_qconfig_options if qconf in act_qconfig_extend_list] - def run_consistency_check(self) -> bool: + def run_consistency_check(self) -> None: all_pqs = self.collect_all_propagating_quantizers() - def traverse_fn(curr_node_key: str, unused) -> Tuple[bool, Any]: + def traverse_fn(curr_node_key: str, _: Any) -> Tuple[bool, Any]: nncf_logger.debug(f"Processing node: {curr_node_key}") node = 
self.nodes[curr_node_key] node_type = node[QuantizerPropagationStateGraph.NODE_TYPE_NODE_ATTR] diff --git a/nncf/common/quantization/quantizer_propagation/grouping.py b/nncf/common/quantization/quantizer_propagation/grouping.py index d6d63533db8..64f30fbb5d3 100644 --- a/nncf/common/quantization/quantizer_propagation/grouping.py +++ b/nncf/common/quantization/quantizer_propagation/grouping.py @@ -21,7 +21,7 @@ class UnifiedScalePropagatingQuantizerGroupManager: quantized model. """ - def __init__(self): + def __init__(self) -> None: self._next_gid = 0 self._group_vs_prop_quants_dict: Dict[int, Set[PropagatingQuantizer]] = {} @@ -46,7 +46,7 @@ def register_group(self, prop_quants: Set[PropagatingQuantizer]) -> int: self._group_vs_prop_quants_dict[gid] = prop_quants return gid - def add_to_group(self, target_gid: int, prop_quant: PropagatingQuantizer): + def add_to_group(self, target_gid: int, prop_quant: PropagatingQuantizer) -> None: """ Adds a propagating quantizer to an already existing group. @@ -62,7 +62,7 @@ def add_to_group(self, target_gid: int, prop_quant: PropagatingQuantizer): ) self._group_vs_prop_quants_dict[target_gid].add(prop_quant) - def remove_from_group(self, group: int, prop_quant: PropagatingQuantizer): + def remove_from_group(self, group: int, prop_quant: PropagatingQuantizer) -> None: """ Removes a propagating quantizer from a group. @@ -91,7 +91,7 @@ def get_group_id_by_propagating_quantizer_id(self, requested_pqid: int) -> Optio return gid return None - def merge_groups(self, merge_to_gid: int, merge_from_gid: int): + def merge_groups(self, merge_to_gid: int, merge_from_gid: int) -> None: """ Merges two groups into a single one. The `merge_to_gid` group retains its group ID. @@ -110,11 +110,11 @@ class QuantizersWaitingForMergeManager: and corresponding node keys. 
""" - def __init__(self): + def __init__(self) -> None: self._branching_node_keys_vs_quantizers_waiting_for_merge: Dict[str, Set[PropagatingQuantizer]] = {} self._quantizers_vs_branching_node_keys: Dict[PropagatingQuantizer, str] = {} - def add_propagating_quantizer_to_wait_on_node_key(self, pq: PropagatingQuantizer, branching_node_key: str): + def add_propagating_quantizer_to_wait_on_node_key(self, pq: PropagatingQuantizer, branching_node_key: str) -> None: """ Registers a propagating quantizer as "waiting" on a node in QuantizerPropagationStateGraph. @@ -146,10 +146,10 @@ def get_waiting_quantizers_for_branching_node_key(self, node_key: str) -> Set[Pr """ return self._branching_node_keys_vs_quantizers_waiting_for_merge[node_key] - def __contains__(self, item: PropagatingQuantizer): + def __contains__(self, item: PropagatingQuantizer) -> bool: return item in self._quantizers_vs_branching_node_keys - def resolve_merged_node(self, branching_node_key: str): + def resolve_merged_node(self, branching_node_key: str) -> None: """ De-registers any quantizers that were previously registered to be "waiting" on a given node key. 
:param branching_node_key: The node key in QuantizerPropagationStateGraph that some propagating diff --git a/nncf/common/quantization/quantizer_propagation/solver.py b/nncf/common/quantization/quantizer_propagation/solver.py index c2861841e6a..604e46fefcd 100644 --- a/nncf/common/quantization/quantizer_propagation/solver.py +++ b/nncf/common/quantization/quantizer_propagation/solver.py @@ -15,9 +15,9 @@ from collections import deque from copy import deepcopy from enum import Enum -from typing import Deque, Dict, List, Optional, Set, Tuple +from typing import Any, Deque, Dict, List, Optional, Set, Tuple, Type, cast -import networkx as nx +import networkx as nx # type: ignore import nncf from nncf.common.graph import NNCFNodeName @@ -80,7 +80,7 @@ def __init__( self._quant_prop_graph = quant_prop_graph @property - def quant_prop_graph(self): + def quant_prop_graph(self) -> QuantizerPropagationStateGraph: return self._quant_prop_graph @@ -114,7 +114,7 @@ def __init__( def constrain_quantizer_config_list_for_insertion( self, quantization_point_id: QuantizationPointId, constrained_config_list: List[QuantizerConfig] - ): + ) -> None: """ Constrains a set of available quantizer configurations for a quantization point with a given ID as defined by the list of quantizer configurations - in essence, performs a selection. 
@@ -136,7 +136,9 @@ def constrain_quantizer_config_list_for_insertion( pq = self._quantization_point_id_vs_prop_quantizer[quantization_point_id] pq.potential_quant_configs = constrained_config_list - def finalize(self, final_quantizer_setup: SingleConfigQuantizerSetup, strict=True) -> FinalizedQuantizationProposal: + def finalize( + self, final_quantizer_setup: SingleConfigQuantizerSetup, strict: bool = True + ) -> FinalizedQuantizationProposal: """ Given a single-configuration quantizer setup (which is constructed by picking a single quantizer configuration for each of the multi-configuration quantization points in this proposal's multi-config setup), prepares a @@ -154,7 +156,7 @@ def finalize(self, final_quantizer_setup: SingleConfigQuantizerSetup, strict=Tru final_qconfig = final_quantizer_setup.quantization_points[qp_id].qconfig if strict: - def is_final_qconfig_compatible_to_initial(initial_qconfig: QuantizerConfig): + def is_final_qconfig_compatible_to_initial(initial_qconfig: QuantizerConfig) -> bool: return ( final_qconfig.per_channel == initial_qconfig.per_channel and final_qconfig.mode == initial_qconfig.mode @@ -198,7 +200,7 @@ def __init__( self, quant_prop_graph: QuantizerPropagationStateGraph, quantizable_layer_nodes: List[QuantizableWeightedLayerNode], - post_processing_marker_metatypes: List[OperatorMetatype], + post_processing_marker_metatypes: List[Type[OperatorMetatype]], ): self._quant_prop_graph = quant_prop_graph self._post_processing_marker_metatypes = post_processing_marker_metatypes @@ -214,13 +216,16 @@ def _is_node_has_underlying_weights(self, node_key: str) -> bool: return True return False - def _get_node_metatype(self, node_key: str) -> OperatorMetatype: + def _get_node_metatype(self, node_key: str) -> Type[OperatorMetatype]: node = self._quant_prop_graph.nodes[node_key] - return node.get(self._quant_prop_graph.OPERATOR_METATYPE_NODE_ATTR) + return cast(Type[OperatorMetatype], 
node.get(self._quant_prop_graph.OPERATOR_METATYPE_NODE_ATTR)) def _is_node_operator(self, node_key: str) -> bool: - node = self._quant_prop_graph.nodes[node_key] - return node.get(self._quant_prop_graph.NODE_TYPE_NODE_ATTR) == QuantizerPropagationStateGraphNodeType.OPERATOR + node = cast(Dict[str, Any], self._quant_prop_graph.nodes[node_key]) + node_type = cast( + Optional[QuantizerPropagationStateGraphNodeType], node.get(self._quant_prop_graph.NODE_TYPE_NODE_ATTR) + ) + return node_type == QuantizerPropagationStateGraphNodeType.OPERATOR def get_post_processing_node_keys(self) -> Set[str]: """ @@ -236,11 +241,11 @@ def get_post_processing_node_keys(self) -> Set[str]: for output_metatype in OUTPUT_NOOP_METATYPES.values(): output_nodes.extend(self._quant_prop_graph.get_node_keys_by_metatype(output_metatype)) - def get_ignored_operations(output_nodes: List[str]) -> Tuple[Set[str], Set[str]]: + def get_ignored_operations(output_nodes: List[str]) -> Set[str]: stack = [([start_node_key], False) for start_node_key in output_nodes] - ignored_operations = set() + ignored_operations: Set[str] = set() - def _extend_ignored_operations(path: List[str]): + def _extend_ignored_operations(path: List[str]) -> None: for node in path: if ( self._is_node_operator(node) @@ -303,23 +308,23 @@ class QuantizerPropagationSolver: def __init__( self, - activation_ignored_scopes: Dict[str, IgnoreReason] = None, - weight_ignored_scopes: List[str] = None, - activation_target_scopes: List[str] = None, - weight_target_scopes: List[str] = None, - hw_config: HWConfig = None, - default_trait_to_metatype_map: Dict[QuantizationTrait, List[OperatorMetatype]] = None, - propagation_strategy: QuantizerPropagationRule = None, - default_qconfig_list: List[QuantizerConfig] = None, - quantizable_layer_nodes: List[QuantizableWeightedLayerNode] = None, - scope_overrides: Dict = None, - global_constraints: Dict[QuantizerGroup, QuantizationConstraints] = None, - additional_unified_scale_op_scopes: List[List[str]] 
= None, + activation_ignored_scopes: Optional[Dict[str, IgnoreReason]] = None, + weight_ignored_scopes: Optional[List[str]] = None, + activation_target_scopes: Optional[List[str]] = None, + weight_target_scopes: Optional[List[str]] = None, + hw_config: Optional[HWConfig] = None, + default_trait_to_metatype_map: Optional[Dict[QuantizationTrait, List[Type[OperatorMetatype]]]] = None, + propagation_strategy: Optional[QuantizerPropagationRule] = None, + default_qconfig_list: Optional[List[QuantizerConfig]] = None, + quantizable_layer_nodes: Optional[List[QuantizableWeightedLayerNode]] = None, + scope_overrides: Optional[Dict[str, Any]] = None, + global_constraints: Optional[Dict[QuantizerGroup, QuantizationConstraints]] = None, + additional_unified_scale_op_scopes: Optional[List[List[str]]] = None, run_consistency_checks: bool = False, quantize_outputs: bool = False, - post_processing_marker_metatypes: List[OperatorMetatype] = None, - metatypes_to_ignore: List[OperatorMetatype] = None, - scales_unification_map: Dict[OperatorMetatype, OperatorMetatype] = None, + post_processing_marker_metatypes: Optional[List[Type[OperatorMetatype]]] = None, + metatypes_to_ignore: Optional[List[Type[OperatorMetatype]]] = None, + scales_unification_map: Optional[Dict[Type[OperatorMetatype], List[Type[OperatorMetatype]]]] = None, ): """ Initializes the solver with parameters affecting the resulting quantizer setup. 
@@ -379,8 +384,8 @@ def __init__( else: self._default_trait_to_metatype_map = default_trait_to_metatype_map self.default_global_qconfig_list = default_qconfig_list - self._hw_config: HWConfig = hw_config - self._visualizer = None + self._hw_config = hw_config + self._visualizer: Optional[Any] = None if is_debug(): from nncf.common.quantization.quantizer_propagation.visualizer import QuantizerPropagationVisualizer @@ -398,10 +403,10 @@ def __init__( ) if scope_overrides is None: - self._scope_overrides = {} + self._scope_overrides: Dict[str, Any] = {} else: - self._scope_overrides: Dict = scope_overrides - self._global_constraints: Dict["QuantizerGroup", "QuantizationConstraints"] = global_constraints + self._scope_overrides = scope_overrides + self._global_constraints = global_constraints self._run_consistency_checks = run_consistency_checks self._unified_scales_operation_set = set() @@ -419,22 +424,21 @@ def __init__( and HWConfig.is_qconf_list_corresponding_to_unspecified_op(qconf_list) ): self._operator_allowed_qconfigs_map[op_meta] = default_qconfig_list - self._active_propagating_quantizers_queue = deque() + self._active_propagating_quantizers_queue: Deque[PropagatingQuantizer] = deque() self._finished_propagating_quantizers: List[PropagatingQuantizer] = [] self._quantizers_waiting_for_branch_merge = QuantizersWaitingForMergeManager() - self._potential_quantizers = {} self._num_potential_quantized_activations = 0 - self._quantizable_layer_nodes = quantizable_layer_nodes + self._quantizable_layer_nodes = quantizable_layer_nodes if quantizable_layer_nodes is not None else [] self._post_processing_marker_metatypes = post_processing_marker_metatypes self._metatypes_to_ignore = metatypes_to_ignore - self._scales_unification_map = scales_unification_map + self._scales_unification_map = scales_unification_map if scales_unification_map is not None else {} def _filter_by_weight_ignored_target_scopes( self, - quantizable_layer_nodes: List[QuantizableWeightedLayerNode], 
- weight_ignored_scopes: List[str], - weight_target_scopes: List[str], + quantizable_layer_nodes: Optional[List[QuantizableWeightedLayerNode]], + weight_ignored_scopes: Optional[List[str]], + weight_target_scopes: Optional[List[str]], ) -> Dict[NNCFNodeName, List[QuantizerConfig]]: if quantizable_layer_nodes is None: return {} @@ -595,11 +599,11 @@ def _handle_quantizer_merge( waiting_pqs: Set[PropagatingQuantizer], quant_prop_graph: QuantizerPropagationStateGraph, branching_node_key: str, - ): + ) -> None: waiting_pqs_list = list(waiting_pqs) - merged_pqs = [] - unmerged_pqs = [] - abort_merge = False + merged_pqs: List[PropagatingQuantizer] = [] + unmerged_pqs: List[PropagatingQuantizer] = [] + abort_merge: bool = False for pq in waiting_pqs_list: # While the quantizers were waiting for the merge, one of the concat nodes # that will be affected by the merge may have been determined to be unquantizable. @@ -704,9 +708,8 @@ def propagation_step( self._finished_propagating_quantizers.append(prop_quantizer) return quant_prop_graph - surviving_prop_quantizers = [] - - prop_quantizers_to_process = [] + surviving_prop_quantizers: List[PropagatingQuantizer] = [] + prop_quantizers_to_process: List[PropagatingQuantizer] = [] did_clone = False # TODO (vshampor): include information on unified scale type in grouping; for now assuming that @@ -717,7 +720,9 @@ def propagation_step( ) ) - unified_scale_path_groups_vs_pqs = {k: [] for k in unified_scale_grouped_paths if k is not None} + unified_scale_path_groups_vs_pqs: Dict[int, List[PropagatingQuantizer]] = { + k: [] for k in unified_scale_grouped_paths if k is not None + } existing_pq_assigned = False for gid, path_group in unified_scale_grouped_paths.items(): for _ in path_group: @@ -743,6 +748,7 @@ def propagation_step( pqs_and_paths = zip(paths, prop_quantizers_to_process) for path, prop_quantizer in pqs_and_paths: + assert prop_quantizer is not None status = self.check_transition_via_path(prop_quantizer, path, 
quant_prop_graph, cloned_prop_quantizers) if status == TransitionStatus.SHOULD_NOT_TRANSITION: if did_clone and prop_quantizer is not curr_prop_quantizer: @@ -764,7 +770,7 @@ def propagation_step( quant_prop_graph.merge_quantizer_into_path(prop_quantizer, path) elif status == TransitionStatus.SHOULD_WAIT_FOR_MERGE: - branching_node_key = None + branching_node_key: Optional[str] = None # type: ignore[no-redef] for from_node_key, _ in path: if len(list(quant_prop_graph.successors(from_node_key))) > 1: branching_node_key = path[0][0] @@ -779,7 +785,9 @@ def propagation_step( self._active_propagating_quantizers_queue.appendleft(prop_quantizer) return quant_prop_graph - def get_allowed_quantizer_configs_for_operator(self, quant_det_id: OperatorMetatype) -> List[QuantizerConfig]: + def get_allowed_quantizer_configs_for_operator( + self, quant_det_id: Type[OperatorMetatype] + ) -> Optional[List[QuantizerConfig]]: """ Returns the quantizer configurations that were determined as allowed for a given metatype by HW config or other means. @@ -820,13 +828,13 @@ def set_allowed_quantization_types_for_operator_nodes( ) return quant_prop_graph - def get_operator_quantization_traits_map(self) -> Dict[OperatorMetatype, QuantizationTrait]: + def get_operator_quantization_traits_map(self) -> Dict[Type[OperatorMetatype], QuantizationTrait]: """ :return: A mapping of operator metatypes to the quantization traits to be assigned to such operations. 
""" # TODO (vshampor): ensure that there are no name collisions between ops in different torch subpackages with # the same name - retval = {} + retval: Dict[Type[OperatorMetatype], QuantizationTrait] = {} if self._hw_config is None: for trait, meta_list in self._default_trait_to_metatype_map.items(): for op_meta in meta_list: @@ -848,7 +856,7 @@ def get_operator_quantization_traits_map(self) -> Dict[OperatorMetatype, Quantiz retval[op_meta] = trait return retval - def _get_trait_for_op_meta_not_specified_in_hw_config(self, op_meta: OperatorMetatype) -> QuantizationTrait: + def _get_trait_for_op_meta_not_specified_in_hw_config(self, op_meta: Type[OperatorMetatype]) -> QuantizationTrait: if not op_meta.hw_config_names: # The metatype might not have an associated name in the config # namespace (yet) - use default trait @@ -870,10 +878,12 @@ def _get_trait_for_op_meta_not_specified_in_hw_config(self, op_meta: OperatorMet return trait - def _get_operator_qconfigs_map(self) -> Dict[OperatorMetatype, List[QuantizerConfig]]: + def _get_operator_qconfigs_map(self) -> Dict[Type[OperatorMetatype], Optional[List[QuantizerConfig]]]: # TODO (vshampor): ensure that there are no name collisions between ops in different torch subpackages # with the same name - retval = {} # Metas not in retval will correspond to wildcard quantization + retval: Dict[Type[OperatorMetatype], Optional[List[QuantizerConfig]]] = ( + {} + ) # Metas not in retval will correspond to wildcard quantization if self._hw_config is None: for trait, meta_list in self._default_trait_to_metatype_map.items(): if trait == QuantizationTrait.INPUTS_QUANTIZABLE: @@ -889,7 +899,7 @@ def _get_operator_qconfigs_map(self) -> Dict[OperatorMetatype, List[QuantizerCon retval = self._hw_config.get_metatype_vs_quantizer_configs_map() return retval - def debug_visualize(self, quant_prop_graph: QuantizerPropagationStateGraph, dump_path: str): + def debug_visualize(self, quant_prop_graph: QuantizerPropagationStateGraph, dump_path: 
str) -> None: """ Visualizes in a .dot format the state of the current quantizer propagation state graph and the associated solver information. @@ -969,7 +979,7 @@ def setup_initial_quantizers( @staticmethod def coalesce_insertion_points( - target_insertion_points: List[TargetPoint], linked_scopes_groups_list: List[List[str]] + target_insertion_points: List[TargetPoint], linked_scopes_groups_list: Optional[List[List[str]]] ) -> List[List[TargetPoint]]: """ Accepts a list of TargetPoints and groups these according to linked_scope_groups_list. @@ -985,20 +995,15 @@ def coalesce_insertion_points( """ if linked_scopes_groups_list is None: - return [ - [ - ip, - ] - for ip in target_insertion_points - ] - retval = [] - insertion_point_indices_vs_group_id = OrderedDict() + return [[ip] for ip in target_insertion_points] + retval: List[List[TargetPoint]] = [] + insertion_point_indices_vs_group_id: Dict[int, Optional[int]] = OrderedDict() for group_idx, group_list in enumerate(linked_scopes_groups_list): for group_member_node_name in group_list: matching_indices = list( filter( - lambda x: target_insertion_points[x].target_node_name == group_member_node_name, + lambda x: target_insertion_points[x].target_node_name == group_member_node_name, # type: ignore range(len(target_insertion_points)), ) ) @@ -1023,27 +1028,21 @@ def coalesce_insertion_points( insertion_point_indices_vs_group_id[i] = None group_indices_list: List[List[int]] = [[] for _ in linked_scopes_groups_list] - for insertion_point_idx, group_idx in insertion_point_indices_vs_group_id.items(): + for insertion_point_idx, group_idx in insertion_point_indices_vs_group_id.items(): # type: ignore[assignment] if group_idx is not None: group_indices_list[group_idx].append(insertion_point_idx) for intra_group_indices in group_indices_list: main_ip_idx = intra_group_indices[0] main_ip = target_insertion_points[main_ip_idx] - grouped_list = [ - main_ip, - ] + grouped_list = [main_ip] for linked_ip_idx in 
intra_group_indices[1:]: grouped_list.append(target_insertion_points[linked_ip_idx]) retval.append(grouped_list) - for insertion_point_idx, group_idx in insertion_point_indices_vs_group_id.items(): + for insertion_point_idx, group_idx in insertion_point_indices_vs_group_id.items(): # type: ignore[assignment] if group_idx is None: - retval.append( - [ - target_insertion_points[insertion_point_idx], - ] - ) + retval.append([target_insertion_points[insertion_point_idx]]) return retval @@ -1072,7 +1071,7 @@ def _filter_qconfigs_according_to_scope( def _setup_initial_quantizers_for_operator_node( self, operator_node_key: str, quant_prop_graph: QuantizerPropagationStateGraph - ): + ) -> None: node = quant_prop_graph.nodes[operator_node_key] # preds are in sorted order for reproducibility @@ -1258,7 +1257,7 @@ def check_branching_transition( def _check_affecting_quantizers_in_common_path( self, affecting_quantizers: List[PropagatingQuantizer], cloned_prop_quantizers: List[PropagatingQuantizer] - ): + ) -> None: # Handling the case where multiple freshly cloned quantizers have to follow paths that are different, # but have a common edge or node safe_affecting_quantizers = [pq for pq in affecting_quantizers if pq in cloned_prop_quantizers] @@ -1368,8 +1367,8 @@ def check_transition_via_path( return TransitionStatus.SHOULD_TRANSITION def get_merged_qconfigs_for_downward_branching_case( - self, potential_qconfigs_for_each_branch: List[List[Optional[QuantizerConfig]]] - ) -> Tuple[Optional[List[QuantizerConfig]], List[Optional[List[QuantizerConfig]]]]: + self, potential_qconfigs_for_each_branch: List[List[QuantizerConfig]] + ) -> Tuple[Optional[List[QuantizerConfig]], List[List[QuantizerConfig]]]: """ Returns a tuple, of which the first node is the qconfig list for the quantizer to be placed above the branching node (i.e. 
that will affect all of the downward branches), and a list @@ -1396,13 +1395,14 @@ def get_merged_qconfigs_for_downward_branching_case( if first_pq_list_counter != Counter(other_pq_list): return None, potential_qconfigs_for_each_branch - return first_pq_list, [None for _ in potential_qconfigs_for_each_branch] + return first_pq_list, [None for _ in potential_qconfigs_for_each_branch] # type: ignore[misc] # Attempt to produce a merged config options space - qconfigs_union = set() + qconfigs_union: Set[QuantizerConfig] = set() for branch_qconfig_list in potential_qconfigs_for_each_branch: + assert branch_qconfig_list is not None qconfigs_union.update(set(branch_qconfig_list)) - merged_qconfig_list = [] + merged_qconfig_list: List[QuantizerConfig] = [] nncf_logger.debug(f"Union of configs: {';'.join([str(qc) for qc in qconfigs_union])}") @@ -1451,7 +1451,9 @@ def compatible_wo_requant(qconf: QuantizerConfig, other_qconf_list: List[Quantiz nncf_logger.debug(f"Disambiguated merge qconfig list: {';'.join([str(qc) for qc in merged_qconfig_list])}") merged_qconfig_list_counter = Counter(merged_qconfig_list) - resulting_branch_qconfig_lists = [None for _ in potential_qconfigs_for_each_branch] + resulting_branch_qconfig_lists: List[List[QuantizerConfig]] = [ + None for _ in potential_qconfigs_for_each_branch # type: ignore[misc] + ] if self._propagation_strategy == QuantizerPropagationRule.MERGE_WITH_POTENTIAL_REQUANTIZATION: for idx, branch_qconfig_list in enumerate(potential_qconfigs_for_each_branch): @@ -1499,7 +1501,7 @@ class QConfigComparator: def __init__(self, qconfig: QuantizerConfig): self.qconfig = qconfig - def __lt__(self, other: "QConfigComparator"): + def __lt__(self, other: "QConfigComparator") -> bool: # Prefer higher bitwidths, per-tensor, symmetrical if self.qconfig.num_bits > other.qconfig.num_bits: return True @@ -1556,7 +1558,7 @@ def get_active_propagating_quantizers_queue(self) -> Deque[PropagatingQuantizer] """ return 
self._active_propagating_quantizers_queue - def get_total_quantizer_count(self): + def get_total_quantizer_count(self) -> int: return len(self.get_finished_propagating_quantizers()) + len(self.get_active_propagating_quantizers_queue()) def _filter_integer_input_quantizers( diff --git a/nncf/common/quantization/quantizer_setup.py b/nncf/common/quantization/quantizer_setup.py index ff2f44af96f..8af1c7fbaac 100644 --- a/nncf/common/quantization/quantizer_setup.py +++ b/nncf/common/quantization/quantizer_setup.py @@ -13,7 +13,7 @@ from collections import Counter from copy import deepcopy from enum import Enum -from typing import Any, Dict, List, Optional, Set +from typing import Any, Dict, List, Optional, Set, Union import nncf from nncf.common.graph import NNCFNodeName @@ -444,13 +444,13 @@ def list2set(pair): class MultiConfigQuantizerSetup(QuantizerSetupBase): - def __init__(self): + def __init__(self) -> None: super().__init__() self.quantization_points: Dict[QuantizationPointId, MultiConfigQuantizationPoint] = {} self._unified_scale_qpid_vs_type: Dict[QuantizationPointId, UnifiedScaleType] = {} def register_unified_scale_group_with_types( - self, qp_group: List[QuantizationPointId], us_types: List[UnifiedScaleType] + self, qp_group: List[QuantizationPointId], us_types: List[Union[UnifiedScaleType, None]] ) -> int: assert len(qp_group) == len(us_types) gid = super().register_unified_scale_group(qp_group) diff --git a/nncf/common/quantization/structs.py b/nncf/common/quantization/structs.py index 507f0779653..3a7cdf04728 100644 --- a/nncf/common/quantization/structs.py +++ b/nncf/common/quantization/structs.py @@ -238,7 +238,10 @@ def from_config_dict(cls, config_dict: Dict[str, Any]) -> "QuantizationConstrain ) def constrain_qconfig_list( - self, node_name: NNCFNodeName, target_device: TargetDevice, quantizer_config_list: List[QuantizerConfig] + self, + node_name: NNCFNodeName, + target_device: Optional[TargetDevice], + quantizer_config_list: List[QuantizerConfig], 
) -> List[QuantizerConfig]: assert quantizer_config_list is not None diff --git a/nncf/common/scopes.py b/nncf/common/scopes.py index ce3bf6afbea..fe2f24ff7b5 100644 --- a/nncf/common/scopes.py +++ b/nncf/common/scopes.py @@ -10,7 +10,7 @@ # limitations under the License. import re -from typing import List, Optional, Set, Union +from typing import Iterable, List, Optional, Sequence, Union import nncf from nncf.common.graph import NNCFGraph @@ -22,7 +22,7 @@ from nncf.scopes import convert_ignored_scope_to_list -def matches_any(tested_str: str, strs_to_match_to: Union[List[str], Set[str], str]) -> bool: +def matches_any(tested_str: str, strs_to_match_to: Union[Iterable[str], str, None]) -> bool: """ Return True if tested_str matches at least one element in strs_to_match_to. @@ -52,8 +52,8 @@ def matches_any(tested_str: str, strs_to_match_to: Union[List[str], Set[str], st def should_consider_scope( serializable_id: Union[QuantizerId, NNCFNodeName], - ignored_scopes: Union[List[str], Set[str]], - target_scopes: Optional[List[str]] = None, + ignored_scopes: Optional[Sequence[str]], + target_scopes: Optional[Sequence[str]] = None, ) -> bool: """ Used when an entity arising during compression has to be compared to an allowlist or a denylist of strings. diff --git a/nncf/common/utils/dot_file_rw.py b/nncf/common/utils/dot_file_rw.py index f106dad0495..448d282bd97 100644 --- a/nncf/common/utils/dot_file_rw.py +++ b/nncf/common/utils/dot_file_rw.py @@ -11,12 +11,12 @@ import copy import pathlib from collections import defaultdict -from typing import Dict +from typing import Dict, Union import networkx as nx # type: ignore -def write_dot_graph(G: nx.DiGraph, path: pathlib.Path) -> None: +def write_dot_graph(G: nx.DiGraph, path: Union[pathlib.Path, str]) -> None: # NOTE: writing dot files with colons even in labels or other node/edge/graph attributes leads to an # error. See https://github.com/networkx/networkx/issues/5962. 
If `relabel` is True in this function, # then the colons (:) will be replaced with (^) symbols. diff --git a/pyproject.toml b/pyproject.toml index 84c7b22cede..779f325532a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,10 +119,6 @@ exclude = [ "nncf/common/pruning/utils.py", "nncf/common/pruning/weights_flops_calculator.py", "nncf/common/quantization/config_assignment.py", - "nncf/common/quantization/initialization/range.py", - "nncf/common/quantization/quantizer_propagation/graph.py", - "nncf/common/quantization/quantizer_propagation/grouping.py", - "nncf/common/quantization/quantizer_propagation/solver.py", "nncf/common/quantization/quantizer_removal.py", "nncf/common/quantization/quantizer_setup.py", ]