Skip to content

Commit

Permalink
[mypy] nncf/common/quantization (part 1) (openvinotoolkit#3192)
Browse files Browse the repository at this point in the history
### Changes

1. Enable mypy for:
    - nncf/common/quantization/quantizers.py
    - nncf/common/quantization/statistics.py
    - nncf/common/quantization/structs.py
    - nncf/common/quantization/quantizer_propagation/structs.py
    - nncf/common/quantization/quantizer_propagation/visualizer.py
2. Inheritance of QuantizationScheme from StrEnum
openvinotoolkit#2629
3. Add QuantizationScheme to Unpickler, to pass pt nightly 


### Tests

nightly/job/torch_nightly/438/
  • Loading branch information
AlexanderDokuchaev authored Jan 17, 2025
1 parent 98f4060 commit 0b80812
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 43 deletions.
1 change: 1 addition & 0 deletions examples/torch/common/restricted_pickle_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class Unpickler(pickle.Unpickler):
"torch.nn": {"Module"},
"torch.optim.adam": {"Adam"},
"nncf.api.compression": {"CompressionStage", "CompressionLevel"},
"nncf.common.quantization.structs": {"QuantizationScheme"},
"numpy.core.multiarray": {"scalar"}, # numpy<2
"numpy._core.multiarray": {"scalar"}, # numpy>=2
"numpy": {"dtype"},
Expand Down
2 changes: 1 addition & 1 deletion nncf/common/hardware/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def from_json(cls: type["HWConfig"], path: str) -> List[Dict[str, Any]]:
return cls.from_dict(json_config)

@staticmethod
def get_quantization_mode_from_config_value(str_val: str) -> str:
def get_quantization_mode_from_config_value(str_val: str) -> QuantizationMode:
if str_val == "symmetric":
return QuantizationMode.SYMMETRIC
if str_val == "asymmetric":
Expand Down
22 changes: 13 additions & 9 deletions nncf/common/quantization/quantizer_propagation/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from enum import Enum
from typing import List, Optional, Set, Tuple

Expand Down Expand Up @@ -67,21 +69,23 @@ def __init__(
this quantizer won't require unified scales.
"""
self.potential_quant_configs: List[QuantizerConfig] = quant_configs
self.affected_edges = set()
self.affected_edges: Set[Tuple[str, str]] = set()
self.affected_ip_nodes: Set[str] = set()
self.propagation_path: PropagationPath = []
self.current_location_node_key = init_location_node_key
self.last_accepting_location_node_key = None
self.last_accepting_location_node_key: Optional[str] = None
self.id = id_
self.unified_scale_type = unified_scale_type
self.affected_operator_nodes = set()
self.quantized_input_sink_operator_nodes = set()
self.downstream_propagating_quantizers = set()
self.affected_operator_nodes: Set[str] = set()
self.quantized_input_sink_operator_nodes: Set[str] = set()
self.downstream_propagating_quantizers: Set[PropagatingQuantizer] = set()

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, PropagatingQuantizer):
return False
return self.id == other.id

def __hash__(self):
def __hash__(self) -> int:
return hash(self.id)


Expand All @@ -95,11 +99,11 @@ class QuantizerPropagationStateGraphNodeType(Enum):
class SharedAffectedOpsPropagatingQuantizerGroup:
"""Combines propagating quantizers that share affected operations"""

def __init__(self, affecting_prop_quants: Set[PropagatingQuantizer], affected_op_node_keys: Set[str]):
def __init__(self, affecting_prop_quants: Set[PropagatingQuantizer], affected_op_node_keys: Set[str]) -> None:
self.affecting_prop_quants: Set[PropagatingQuantizer] = affecting_prop_quants
self.affected_op_node_keys: Set[str] = affected_op_node_keys

def update(self, other: "SharedAffectedOpsPropagatingQuantizerGroup"):
def update(self, other: SharedAffectedOpsPropagatingQuantizerGroup) -> None:
self.affected_op_node_keys.update(other.affected_op_node_keys)
self.affecting_prop_quants.update(other.affecting_prop_quants)

Expand Down
4 changes: 2 additions & 2 deletions nncf/common/quantization/quantizer_propagation/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ class QuantizerPropagationVisualizer:
An object performing visualization of the quantizer propagation algorithm's state into a chosen directory.
"""

def __init__(self, dump_dir: str = None):
def __init__(self, dump_dir: str):
self.dump_dir = Path(dump_dir)
if self.dump_dir.exists():
shutil.rmtree(str(self.dump_dir))

def visualize_quantizer_propagation(
self, prop_solver: QuantizerPropagationSolver, prop_graph: QuantizerPropagationStateGraph, iteration: str
):
) -> None:
self.dump_dir.mkdir(parents=True, exist_ok=True)
fname = "quant_prop_iter_{}.dot".format(iteration)
prop_solver.debug_visualize(prop_graph, str(self.dump_dir / Path(fname)))
2 changes: 1 addition & 1 deletion nncf/common/quantization/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,5 @@ def calculate_asymmetric_level_ranges(num_bits: int, narrow_range: bool = False)
return level_low, level_high


def get_num_levels(level_low: int, level_high: int):
def get_num_levels(level_low: int, level_high: int) -> int:
return level_high - level_low + 1
6 changes: 3 additions & 3 deletions nncf/common/quantization/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from nncf.common.utils.helpers import create_table


def _proportion_str(num: int, total_count: int):
def _proportion_str(num: int, total_count: int) -> str:
percentage = 100 * (num / max(total_count, 1))
return f"{percentage:.2f} % ({num} / {total_count})"

Expand Down Expand Up @@ -170,12 +170,12 @@ def _get_bitwidth_distribution_str(self) -> str:
q_total_num = wq_total_num + aq_total_num

bitwidths = self.num_wq_per_bitwidth.keys() | self.num_aq_per_bitwidth.keys() # union of all bitwidths
bitwidths = sorted(bitwidths, reverse=True)
bitwidths_sorted = sorted(bitwidths, reverse=True)

# Table creation
header = ["Num bits (N)", "N-bits WQs / Placed WQs", "N-bits AQs / Placed AQs", "N-bits Qs / Placed Qs"]
rows = []
for bitwidth in bitwidths:
for bitwidth in bitwidths_sorted:
wq_num = self.num_wq_per_bitwidth.get(bitwidth, 0) # for current bitwidth
aq_num = self.num_aq_per_bitwidth.get(bitwidth, 0) # for current bitwidth
q_num = wq_num + aq_num # for current bitwidth
Expand Down
45 changes: 28 additions & 17 deletions nncf/common/quantization/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional

import nncf
from nncf.common.graph import NNCFNode
Expand All @@ -24,7 +24,7 @@


@api()
class QuantizationScheme:
class QuantizationScheme(StrEnum):
"""
Basic enumeration for quantization scheme specification.
Expand All @@ -45,7 +45,7 @@ class QuantizerConfig:
def __init__(
self,
num_bits: int = QUANTIZATION_BITS,
mode: Union[QuantizationScheme, str] = QuantizationScheme.SYMMETRIC, # TODO(AlexanderDokuchaev): use enum
mode: QuantizationScheme = QuantizationScheme.SYMMETRIC,
signedness_to_force: Optional[bool] = None,
per_channel: bool = QUANTIZATION_PER_CHANNEL,
):
Expand All @@ -62,18 +62,20 @@ def __init__(
self.signedness_to_force = signedness_to_force
self.per_channel = per_channel

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, QuantizerConfig):
return False
return self.__dict__ == other.__dict__

def __str__(self):
def __str__(self) -> str:
return "B:{bits} M:{mode} SGN:{signedness} PC:{per_channel}".format(
bits=self.num_bits,
mode="S" if self.mode == QuantizationScheme.SYMMETRIC else "A",
signedness="ANY" if self.signedness_to_force is None else ("S" if self.signedness_to_force else "U"),
per_channel="Y" if self.per_channel else "N",
)

def __hash__(self):
def __hash__(self) -> int:
return hash(str(self))

def is_valid_requantization_for(self, other: "QuantizerConfig") -> bool:
Expand All @@ -96,7 +98,7 @@ def is_valid_requantization_for(self, other: "QuantizerConfig") -> bool:
return False
return True

def compatible_with_a_unified_scale_linked_qconfig(self, linked_qconfig: "QuantizerConfig"):
def compatible_with_a_unified_scale_linked_qconfig(self, linked_qconfig: "QuantizerConfig") -> bool:
"""
For two configs to be compatible in a unified scale scenario, all of their fundamental parameters
must be aligned.
Expand Down Expand Up @@ -155,7 +157,12 @@ class QuantizerSpec:
"""

def __init__(
self, num_bits: int, mode: QuantizationScheme, signedness_to_force: bool, narrow_range: bool, half_range: bool
self,
num_bits: int,
mode: QuantizationScheme,
signedness_to_force: Optional[bool],
narrow_range: Optional[bool],
half_range: bool,
):
"""
:param num_bits: Bitwidth of the quantization.
Expand All @@ -174,7 +181,9 @@ def __init__(
self.narrow_range = narrow_range
self.half_range = half_range

def __eq__(self, other: "QuantizerSpec"):
def __eq__(self, other: object) -> bool:
if not isinstance(other, QuantizerSpec):
return False
return self.__dict__ == other.__dict__

@classmethod
Expand All @@ -185,7 +194,7 @@ def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool, half_range: b
class QuantizationConstraints:
REF_QCONF_OBJ = QuantizerConfig()

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
"""
Use attribute names of QuantizerConfig as arguments
to set up constraints.
Expand Down Expand Up @@ -220,7 +229,7 @@ def get_updated_constraints(self, overriding_constraints: "QuantizationConstrain
return QuantizationConstraints(**new_dict)

@classmethod
def from_config_dict(cls, config_dict: Dict) -> "QuantizationConstraints":
def from_config_dict(cls, config_dict: Dict[str, Any]) -> "QuantizationConstraints":
return cls(
num_bits=config_dict.get("bits"),
mode=config_dict.get("mode"),
Expand Down Expand Up @@ -264,19 +273,21 @@ class QuantizerId:
structure.
"""

def get_base(self):
def get_base(self) -> str:
raise NotImplementedError

def get_suffix(self) -> str:
raise NotImplementedError

def __str__(self):
def __str__(self) -> str:
return str(self.get_base()) + self.get_suffix()

def __hash__(self):
def __hash__(self) -> int:
return hash((self.get_base(), self.get_suffix()))

def __eq__(self, other: "QuantizerId"):
def __eq__(self, other: object) -> bool:
if not isinstance(other, QuantizerId):
return False
return (self.get_base() == other.get_base()) and (self.get_suffix() == other.get_suffix())


Expand All @@ -299,7 +310,7 @@ class NonWeightQuantizerId(QuantizerId):
ordinary activation, function and input
"""

def __init__(self, target_node_name: NNCFNodeName, input_port_id=None):
def __init__(self, target_node_name: NNCFNodeName, input_port_id: Optional[int] = None):
self.target_node_name = target_node_name
self.input_port_id = input_port_id

Expand Down Expand Up @@ -335,7 +346,7 @@ class QuantizationPreset(StrEnum):
PERFORMANCE = "performance"
MIXED = "mixed"

def get_params_configured_by_preset(self, quant_group: QuantizerGroup) -> Dict:
def get_params_configured_by_preset(self, quant_group: QuantizerGroup) -> Dict[str, str]:
if quant_group == QuantizerGroup.ACTIVATIONS and self == QuantizationPreset.MIXED:
return {"mode": QuantizationScheme.ASYMMETRIC}
return {"mode": QuantizationScheme.SYMMETRIC}
12 changes: 7 additions & 5 deletions nncf/common/stateful_classes_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
# limitations under the License.

import inspect
from typing import Callable, Dict
from typing import Callable, Dict, TypeVar

TObj = TypeVar("TObj", bound=type)


class StatefulClassesRegistry:
Expand All @@ -24,15 +26,15 @@ def __init__(self) -> None:
self._name_vs_class_map: Dict[str, type] = {}
self._class_vs_name_map: Dict[type, str] = {}

def register(self, name: str = None) -> Callable[[type], type]:
def register(self, name: str = None) -> Callable[[TObj], TObj]:
"""
Decorator to map class with some name - specified in the argument or name of the class.
:param name: The registration name. By default, it's name of the class.
:return: The inner function for registration.
"""

def decorator(cls: type) -> type:
def decorator(cls: TObj) -> TObj:
registered_name = name if name is not None else cls.__name__

if registered_name in self._name_vs_class_map:
Expand Down Expand Up @@ -88,15 +90,15 @@ class CommonStatefulClassesRegistry:
"""

@staticmethod
def register(name: str = None) -> Callable[[type], type]:
def register(name: str = None) -> Callable[[TObj], TObj]:
"""
Decorator to map class with some name - specified in the argument or name of the class.
:param name: The registration name. By default, it's name of the class.
:return: The inner function for registration.
"""

def decorator(cls: type) -> type:
def decorator(cls: TObj) -> TObj:
PT_STATEFUL_CLASSES.register(name)(cls)
TF_STATEFUL_CLASSES.register(name)(cls)
return cls
Expand Down
5 changes: 0 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,8 @@ exclude = [
"nncf/common/quantization/quantizer_propagation/graph.py",
"nncf/common/quantization/quantizer_propagation/grouping.py",
"nncf/common/quantization/quantizer_propagation/solver.py",
"nncf/common/quantization/quantizer_propagation/structs.py",
"nncf/common/quantization/quantizer_propagation/visualizer.py",
"nncf/common/quantization/quantizer_removal.py",
"nncf/common/quantization/quantizer_setup.py",
"nncf/common/quantization/quantizers.py",
"nncf/common/quantization/statistics.py",
"nncf/common/quantization/structs.py",
]

[tool.ruff]
Expand Down

0 comments on commit 0b80812

Please sign in to comment.