From 0b8081266403dd63da48909f539268a26a5e472a Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Fri, 17 Jan 2025 16:57:27 +0200 Subject: [PATCH] [mypy] nncf/common/quantization (part 1) (#3192) ### Changes 1. Enable mypy for: - nncf/common/quantization/quantizers.py - nncf/common/quantization/statistics.py - nncf/common/quantization/structs.py - nncf/common/quantization/quantizer_propagation/structs.py - nncf/common/quantization/quantizer_propagation/visualizer.py 2. Inheritance of QuantizationScheme from StrEnum https://github.com/openvinotoolkit/nncf/pull/2629 3. Add QuantizationScheme to Unpickler, to pass pt nightly ### Tests nightly/job/torch_nightly/438/ --- .../torch/common/restricted_pickle_module.py | 1 + nncf/common/hardware/config.py | 2 +- .../quantizer_propagation/structs.py | 22 +++++---- .../quantizer_propagation/visualizer.py | 4 +- nncf/common/quantization/quantizers.py | 2 +- nncf/common/quantization/statistics.py | 6 +-- nncf/common/quantization/structs.py | 45 ++++++++++++------- nncf/common/stateful_classes_registry.py | 12 ++--- pyproject.toml | 5 --- 9 files changed, 56 insertions(+), 43 deletions(-) diff --git a/examples/torch/common/restricted_pickle_module.py b/examples/torch/common/restricted_pickle_module.py index 893f32356c2..d6c0af5239e 100644 --- a/examples/torch/common/restricted_pickle_module.py +++ b/examples/torch/common/restricted_pickle_module.py @@ -38,6 +38,7 @@ class Unpickler(pickle.Unpickler): "torch.nn": {"Module"}, "torch.optim.adam": {"Adam"}, "nncf.api.compression": {"CompressionStage", "CompressionLevel"}, + "nncf.common.quantization.structs": {"QuantizationScheme"}, "numpy.core.multiarray": {"scalar"}, # numpy<2 "numpy._core.multiarray": {"scalar"}, # numpy>=2 "numpy": {"dtype"}, diff --git a/nncf/common/hardware/config.py b/nncf/common/hardware/config.py index 0b1185606f3..799261ced0c 100644 --- a/nncf/common/hardware/config.py +++ b/nncf/common/hardware/config.py @@ -136,7 +136,7 @@ def from_json(cls: type["HWConfig"], path: str) -> List[Dict[str, Any]]: return cls.from_dict(json_config) @staticmethod - def get_quantization_mode_from_config_value(str_val: str) -> str: + def get_quantization_mode_from_config_value(str_val: str) -> QuantizationMode: if str_val == "symmetric": return QuantizationMode.SYMMETRIC if str_val == "asymmetric": diff --git a/nncf/common/quantization/quantizer_propagation/structs.py b/nncf/common/quantization/quantizer_propagation/structs.py index a163ac6844f..741f07c5348 100644 --- a/nncf/common/quantization/quantizer_propagation/structs.py +++ b/nncf/common/quantization/quantizer_propagation/structs.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from enum import Enum from typing import List, Optional, Set, Tuple @@ -67,21 +69,23 @@ def __init__( this quantizer won't require unified scales. """ self.potential_quant_configs: List[QuantizerConfig] = quant_configs - self.affected_edges = set() + self.affected_edges: Set[Tuple[str, str]] = set() self.affected_ip_nodes: Set[str] = set() self.propagation_path: PropagationPath = [] self.current_location_node_key = init_location_node_key - self.last_accepting_location_node_key = None + self.last_accepting_location_node_key: Optional[str] = None self.id = id_ self.unified_scale_type = unified_scale_type - self.affected_operator_nodes = set() - self.quantized_input_sink_operator_nodes = set() - self.downstream_propagating_quantizers = set() + self.affected_operator_nodes: Set[str] = set() + self.quantized_input_sink_operator_nodes: Set[str] = set() + self.downstream_propagating_quantizers: Set[PropagatingQuantizer] = set() - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, PropagatingQuantizer): + return False return self.id == other.id - def __hash__(self): + def __hash__(self) -> int: return hash(self.id) @@ -95,11 +99,11 @@ class QuantizerPropagationStateGraphNodeType(Enum): class SharedAffectedOpsPropagatingQuantizerGroup: """Combines propagating quantizers that share affected operations""" - def __init__(self, affecting_prop_quants: Set[PropagatingQuantizer], affected_op_node_keys: Set[str]): + def __init__(self, affecting_prop_quants: Set[PropagatingQuantizer], affected_op_node_keys: Set[str]) -> None: self.affecting_prop_quants: Set[PropagatingQuantizer] = affecting_prop_quants self.affected_op_node_keys: Set[str] = affected_op_node_keys - def update(self, other: "SharedAffectedOpsPropagatingQuantizerGroup"): + def update(self, other: SharedAffectedOpsPropagatingQuantizerGroup) -> None: self.affected_op_node_keys.update(other.affected_op_node_keys) self.affecting_prop_quants.update(other.affecting_prop_quants) diff --git a/nncf/common/quantization/quantizer_propagation/visualizer.py b/nncf/common/quantization/quantizer_propagation/visualizer.py index 817cf409119..b39b392b5af 100644 --- a/nncf/common/quantization/quantizer_propagation/visualizer.py +++ b/nncf/common/quantization/quantizer_propagation/visualizer.py @@ -20,14 +20,14 @@ class QuantizerPropagationVisualizer: An object performing visualization of the quantizer propagation algorithm's state into a chosen directory. """ - def __init__(self, dump_dir: str = None): + def __init__(self, dump_dir: str): self.dump_dir = Path(dump_dir) if self.dump_dir.exists(): shutil.rmtree(str(self.dump_dir)) def visualize_quantizer_propagation( self, prop_solver: QuantizerPropagationSolver, prop_graph: QuantizerPropagationStateGraph, iteration: str - ): + ) -> None: self.dump_dir.mkdir(parents=True, exist_ok=True) fname = "quant_prop_iter_{}.dot".format(iteration) prop_solver.debug_visualize(prop_graph, str(self.dump_dir / Path(fname))) diff --git a/nncf/common/quantization/quantizers.py b/nncf/common/quantization/quantizers.py index 7ebddb11e5d..45e306adf0d 100644 --- a/nncf/common/quantization/quantizers.py +++ b/nncf/common/quantization/quantizers.py @@ -66,5 +66,5 @@ def calculate_asymmetric_level_ranges(num_bits: int, narrow_range: bool = False) return level_low, level_high -def get_num_levels(level_low: int, level_high: int): +def get_num_levels(level_low: int, level_high: int) -> int: return level_high - level_low + 1 diff --git a/nncf/common/quantization/statistics.py b/nncf/common/quantization/statistics.py index bac232d9374..09ac1969fdc 100644 --- a/nncf/common/quantization/statistics.py +++ b/nncf/common/quantization/statistics.py @@ -16,7 +16,7 @@ from nncf.common.utils.helpers import create_table -def _proportion_str(num: int, total_count: int): +def _proportion_str(num: int, total_count: int) -> str: percentage = 100 * (num / max(total_count, 1)) return f"{percentage:.2f} % ({num} / {total_count})" @@ -170,12 +170,12 @@ def _get_bitwidth_distribution_str(self) -> str: q_total_num = wq_total_num + aq_total_num bitwidths = self.num_wq_per_bitwidth.keys() | self.num_aq_per_bitwidth.keys() # union of all bitwidths - bitwidths = sorted(bitwidths, reverse=True) + bitwidths_sorted = sorted(bitwidths, reverse=True) # Table creation header = ["Num bits (N)", "N-bits WQs / Placed WQs", "N-bits AQs / Placed AQs", "N-bits Qs / Placed Qs"] rows = [] - for bitwidth in bitwidths: + for bitwidth in bitwidths_sorted: wq_num = self.num_wq_per_bitwidth.get(bitwidth, 0) # for current bitwidth aq_num = self.num_aq_per_bitwidth.get(bitwidth, 0) # for current bitwidth q_num = wq_num + aq_num # for current bitwidth diff --git a/nncf/common/quantization/structs.py b/nncf/common/quantization/structs.py index 4d4806cc555..507f0779653 100644 --- a/nncf/common/quantization/structs.py +++ b/nncf/common/quantization/structs.py @@ -11,7 +11,7 @@ from copy import deepcopy from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional import nncf from nncf.common.graph import NNCFNode @@ -24,7 +24,7 @@ @api() -class QuantizationScheme: +class QuantizationScheme(StrEnum): """ Basic enumeration for quantization scheme specification. @@ -45,7 +45,7 @@ class QuantizerConfig: def __init__( self, num_bits: int = QUANTIZATION_BITS, - mode: Union[QuantizationScheme, str] = QuantizationScheme.SYMMETRIC, # TODO(AlexanderDokuchaev): use enum + mode: QuantizationScheme = QuantizationScheme.SYMMETRIC, signedness_to_force: Optional[bool] = None, per_channel: bool = QUANTIZATION_PER_CHANNEL, ): @@ -62,10 +62,12 @@ def __init__( self.signedness_to_force = signedness_to_force self.per_channel = per_channel - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, QuantizerConfig): + return False return self.__dict__ == other.__dict__ - def __str__(self): + def __str__(self) -> str: return "B:{bits} M:{mode} SGN:{signedness} PC:{per_channel}".format( bits=self.num_bits, mode="S" if self.mode == QuantizationScheme.SYMMETRIC else "A", @@ -73,7 +75,7 @@ def __str__(self): per_channel="Y" if self.per_channel else "N", ) - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) def is_valid_requantization_for(self, other: "QuantizerConfig") -> bool: @@ -96,7 +98,7 @@ def is_valid_requantization_for(self, other: "QuantizerConfig") -> bool: return False return True - def compatible_with_a_unified_scale_linked_qconfig(self, linked_qconfig: "QuantizerConfig"): + def compatible_with_a_unified_scale_linked_qconfig(self, linked_qconfig: "QuantizerConfig") -> bool: """ For two configs to be compatible in a unified scale scenario, all of their fundamental parameters must be aligned. @@ -155,7 +157,12 @@ class QuantizerSpec: """ def __init__( - self, num_bits: int, mode: QuantizationScheme, signedness_to_force: bool, narrow_range: bool, half_range: bool + self, + num_bits: int, + mode: QuantizationScheme, + signedness_to_force: Optional[bool], + narrow_range: Optional[bool], + half_range: bool, ): """ :param num_bits: Bitwidth of the quantization. @@ -174,7 +181,9 @@ def __init__( self.narrow_range = narrow_range self.half_range = half_range - def __eq__(self, other: "QuantizerSpec"): + def __eq__(self, other: object) -> bool: + if not isinstance(other, QuantizerSpec): + return False return self.__dict__ == other.__dict__ @classmethod @@ -185,7 +194,7 @@ def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool, half_range: b class QuantizationConstraints: REF_QCONF_OBJ = QuantizerConfig() - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: """ Use attribute names of QuantizerConfig as arguments to set up constraints. @@ -220,7 +229,7 @@ def get_updated_constraints(self, overriding_constraints: "QuantizationConstrain return QuantizationConstraints(**new_dict) @classmethod - def from_config_dict(cls, config_dict: Dict) -> "QuantizationConstraints": + def from_config_dict(cls, config_dict: Dict[str, Any]) -> "QuantizationConstraints": return cls( num_bits=config_dict.get("bits"), mode=config_dict.get("mode"), @@ -264,19 +273,21 @@ class QuantizerId: structure. """ - def get_base(self): + def get_base(self) -> str: raise NotImplementedError def get_suffix(self) -> str: raise NotImplementedError - def __str__(self): + def __str__(self) -> str: return str(self.get_base()) + self.get_suffix() - def __hash__(self): + def __hash__(self) -> int: return hash((self.get_base(), self.get_suffix())) - def __eq__(self, other: "QuantizerId"): + def __eq__(self, other: object) -> bool: + if not isinstance(other, QuantizerId): + return False return (self.get_base() == other.get_base()) and (self.get_suffix() == other.get_suffix()) @@ -299,7 +310,7 @@ class NonWeightQuantizerId(QuantizerId): ordinary activation, function and input """ - def __init__(self, target_node_name: NNCFNodeName, input_port_id=None): + def __init__(self, target_node_name: NNCFNodeName, input_port_id: Optional[int] = None): self.target_node_name = target_node_name self.input_port_id = input_port_id @@ -335,7 +346,7 @@ class QuantizationPreset(StrEnum): PERFORMANCE = "performance" MIXED = "mixed" - def get_params_configured_by_preset(self, quant_group: QuantizerGroup) -> Dict: + def get_params_configured_by_preset(self, quant_group: QuantizerGroup) -> Dict[str, str]: if quant_group == QuantizerGroup.ACTIVATIONS and self == QuantizationPreset.MIXED: return {"mode": QuantizationScheme.ASYMMETRIC} return {"mode": QuantizationScheme.SYMMETRIC} diff --git a/nncf/common/stateful_classes_registry.py b/nncf/common/stateful_classes_registry.py index dd1983280ed..7a03a63c27b 100644 --- a/nncf/common/stateful_classes_registry.py +++ b/nncf/common/stateful_classes_registry.py @@ -10,7 +10,9 @@ # limitations under the License. import inspect -from typing import Callable, Dict +from typing import Callable, Dict, TypeVar + +TObj = TypeVar("TObj", bound=type) class StatefulClassesRegistry: @@ -24,7 +26,7 @@ def __init__(self) -> None: self._name_vs_class_map: Dict[str, type] = {} self._class_vs_name_map: Dict[type, str] = {} - def register(self, name: str = None) -> Callable[[type], type]: + def register(self, name: str = None) -> Callable[[TObj], TObj]: """ Decorator to map class with some name - specified in the argument or name of the class. @@ -32,7 +34,7 @@ def register(self, name: str = None) -> Callable[[type], type]: :return: The inner function for registration. """ - def decorator(cls: type) -> type: + def decorator(cls: TObj) -> TObj: registered_name = name if name is not None else cls.__name__ if registered_name in self._name_vs_class_map: @@ -88,7 +90,7 @@ class CommonStatefulClassesRegistry: """ @staticmethod - def register(name: str = None) -> Callable[[type], type]: + def register(name: str = None) -> Callable[[TObj], TObj]: """ Decorator to map class with some name - specified in the argument or name of the class. @@ -96,7 +98,7 @@ def register(name: str = None) -> Callable[[type], type]: :return: The inner function for registration. """ - def decorator(cls: type) -> type: + def decorator(cls: TObj) -> TObj: PT_STATEFUL_CLASSES.register(name)(cls) TF_STATEFUL_CLASSES.register(name)(cls) return cls diff --git a/pyproject.toml b/pyproject.toml index 64853599d58..84c7b22cede 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,13 +123,8 @@ exclude = [ "nncf/common/quantization/quantizer_propagation/graph.py", "nncf/common/quantization/quantizer_propagation/grouping.py", "nncf/common/quantization/quantizer_propagation/solver.py", - "nncf/common/quantization/quantizer_propagation/structs.py", - "nncf/common/quantization/quantizer_propagation/visualizer.py", "nncf/common/quantization/quantizer_removal.py", "nncf/common/quantization/quantizer_setup.py", - "nncf/common/quantization/quantizers.py", - "nncf/common/quantization/statistics.py", - "nncf/common/quantization/structs.py", ] [tool.ruff]