Skip to content

Commit

Permalink
[mypy] nncf/common/quantization (part 2) (openvinotoolkit#3197)
Browse files Browse the repository at this point in the history
### Changes

Enable mypy for part of files in nncf/common/quantization
  • Loading branch information
AlexanderDokuchaev authored Jan 21, 2025
1 parent d1b5229 commit 0931072
Show file tree
Hide file tree
Showing 9 changed files with 208 additions and 181 deletions.
24 changes: 15 additions & 9 deletions nncf/common/quantization/initialization/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from nncf.common.graph.utils import get_reduction_axes
from nncf.common.initialization.dataloader import NNCFDataLoader
Expand All @@ -26,7 +27,12 @@ class RangeInitConfig:
parameters.
"""

def __init__(self, init_type: str, num_init_samples: int, init_type_specific_params: Dict = None):
def __init__(
self,
init_type: str,
num_init_samples: int,
init_type_specific_params: Optional[Dict[str, int]] = None,
):
"""
Initializes the quantization range initialization parameters.
Expand All @@ -43,11 +49,11 @@ def __init__(self, init_type: str, num_init_samples: int, init_type_specific_par
if self.init_type_specific_params is None:
self.init_type_specific_params = {}

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
return self.__dict__ == other.__dict__

@classmethod
def from_dict(cls, dct: Dict) -> "RangeInitConfig":
def from_dict(cls, dct: Dict[str, Any]) -> RangeInitConfig:
num_init_samples = dct.get("num_init_samples", NUM_INIT_SAMPLES)
if num_init_samples < 0:
raise ValueError("Number of initialization samples must be >= 0")
Expand Down Expand Up @@ -94,10 +100,10 @@ def __init__(
self.target_group = target_quantizer_group

@classmethod
def from_dict(cls, dct: Dict) -> "PerLayerRangeInitConfig":
def from_dict(cls, dct: Dict[str, Any]) -> PerLayerRangeInitConfig:
base_config = RangeInitConfig.from_dict(dct)

def get_list(dct: Dict, attr_name: str) -> Optional[List[str]]:
def get_list(dct: Dict[str, Any], attr_name: str) -> Optional[List[str]]:
str_or_list = dct.get(attr_name)
if str_or_list is None:
return None
Expand Down Expand Up @@ -185,7 +191,7 @@ def is_per_channel(self) -> bool:
"""
return self._is_per_channel

def use_per_sample_stats(self, per_sample_stats) -> bool:
def use_per_sample_stats(self, per_sample_stats: bool) -> bool:
"""
For activations, if per_sample_stats is True, statistics will be collected per-sample.
For weights statistics are always collected per-batch.
Expand Down Expand Up @@ -213,7 +219,7 @@ def _get_reduction_axes(
shape_to_reduce: Union[Tuple[int, ...], List[int]],
quantization_axes: Union[Tuple[int, ...], List[int]],
aggregation_axes: Union[Tuple[int, ...], List[int]],
):
) -> Tuple[int, ...]:
"""
Returns axes for a reducer regarding aggregation axes. As aggregator takes axes counting from stacked tensors,
from these axes only tensor related axes should be used for reducer.
Expand All @@ -225,7 +231,7 @@ def _get_reduction_axes(
"""
axes_to_keep = set(el - 1 for el in aggregation_axes if el != 0)
axes_to_keep.update(quantization_axes)
return get_reduction_axes(axes_to_keep, shape_to_reduce)
return get_reduction_axes(list(axes_to_keep), shape_to_reduce)

def _get_aggregation_axes(self, batchwise_statistics: bool) -> Tuple[int, ...]:
"""
Expand Down
146 changes: 83 additions & 63 deletions nncf/common/quantization/quantizer_propagation/graph.py

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions nncf/common/quantization/quantizer_propagation/grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class UnifiedScalePropagatingQuantizerGroupManager:
quantized model.
"""

def __init__(self):
def __init__(self) -> None:
self._next_gid = 0
self._group_vs_prop_quants_dict: Dict[int, Set[PropagatingQuantizer]] = {}

Expand All @@ -46,7 +46,7 @@ def register_group(self, prop_quants: Set[PropagatingQuantizer]) -> int:
self._group_vs_prop_quants_dict[gid] = prop_quants
return gid

def add_to_group(self, target_gid: int, prop_quant: PropagatingQuantizer):
def add_to_group(self, target_gid: int, prop_quant: PropagatingQuantizer) -> None:
"""
Adds a propagating quantizer to an already existing group.
Expand All @@ -62,7 +62,7 @@ def add_to_group(self, target_gid: int, prop_quant: PropagatingQuantizer):
)
self._group_vs_prop_quants_dict[target_gid].add(prop_quant)

def remove_from_group(self, group: int, prop_quant: PropagatingQuantizer):
def remove_from_group(self, group: int, prop_quant: PropagatingQuantizer) -> None:
"""
Removes a propagating quantizer from a group.
Expand Down Expand Up @@ -91,7 +91,7 @@ def get_group_id_by_propagating_quantizer_id(self, requested_pqid: int) -> Optio
return gid
return None

def merge_groups(self, merge_to_gid: int, merge_from_gid: int):
def merge_groups(self, merge_to_gid: int, merge_from_gid: int) -> None:
"""
Merges two groups into a single one. The `merge_to_gid` group retains its group ID.
Expand All @@ -110,11 +110,11 @@ class QuantizersWaitingForMergeManager:
and corresponding node keys.
"""

def __init__(self):
def __init__(self) -> None:
self._branching_node_keys_vs_quantizers_waiting_for_merge: Dict[str, Set[PropagatingQuantizer]] = {}
self._quantizers_vs_branching_node_keys: Dict[PropagatingQuantizer, str] = {}

def add_propagating_quantizer_to_wait_on_node_key(self, pq: PropagatingQuantizer, branching_node_key: str):
def add_propagating_quantizer_to_wait_on_node_key(self, pq: PropagatingQuantizer, branching_node_key: str) -> None:
"""
Registers a propagating quantizer as "waiting" on a node in QuantizerPropagationStateGraph.
Expand Down Expand Up @@ -146,10 +146,10 @@ def get_waiting_quantizers_for_branching_node_key(self, node_key: str) -> Set[Pr
"""
return self._branching_node_keys_vs_quantizers_waiting_for_merge[node_key]

def __contains__(self, item: PropagatingQuantizer):
def __contains__(self, item: PropagatingQuantizer) -> bool:
return item in self._quantizers_vs_branching_node_keys

def resolve_merged_node(self, branching_node_key: str):
def resolve_merged_node(self, branching_node_key: str) -> None:
"""
De-registers any quantizers that were previously registered to be "waiting" on a given node key.
:param branching_node_key: The node key in QuantizerPropagationStateGraph that some propagating
Expand Down
Loading

0 comments on commit 0931072

Please sign in to comment.