diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py
index a735ad59cb9..eae40dbffe7 100644
--- a/nncf/quantization/algorithms/min_max/torch_backend.py
+++ b/nncf/quantization/algorithms/min_max/torch_backend.py
@@ -268,22 +268,22 @@ def _create_quantizer(
         quantizer = quantizer_cls(quantizer_spec)

         # Fill it with minmax
-        PTMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters)
+        PTMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters, quantizer_spec.scale_shape)
         return quantizer

     @staticmethod
-    def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters) -> None:
+    def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters, scale_shape) -> None:
         if isinstance(quantizer, AsymmetricQuantizer):
-            quantizer.input_low = torch.nn.Parameter(parameters.input_low.data)
+            quantizer.input_low = torch.nn.Parameter(parameters.input_low.data.reshape(scale_shape))
             input_range = parameters.input_high - parameters.input_low
             # Subtract eps from the input_range to make quantizer parameters equal to
             # original parameters on the forward call.
-            quantizer.input_range = torch.nn.Parameter(input_range.data - quantizer.eps)
+            quantizer.input_range = torch.nn.Parameter((input_range.data - quantizer.eps).reshape(scale_shape))
         else:
             quantizer.signed = bool(torch.any(parameters.input_low.data < 0))
             # Subtract eps from the scale to make quantizer parameters equal to
             # original parameters on the forward call.
-            quantizer.scale = torch.nn.Parameter(parameters.input_high.data - quantizer.eps)
+            quantizer.scale = torch.nn.Parameter((parameters.input_high.data - quantizer.eps).reshape(scale_shape))

     @staticmethod
     def create_quantizer_insertion_command(
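For illustration, a minimal sketch of the invariant the hunk above aims at, with assumed shapes and an assumed eps value: the new scale_shape argument reshapes the flat FakeQuantizeParameters tensors to the quantizer's scale shape, and the eps subtraction compensates for the eps the quantizer adds back on the forward call, as the comments above state. This shape preservation is also why the reference-scale JSON files further down change scalar values into single-element lists.

import torch

scale_shape = (1, 4, 1, 1)  # assumed per-channel scale shape for an NCHW weight
input_high = torch.rand(4)  # flat per-channel values, as FakeQuantizeParameters holds them
eps = 1e-16                 # assumed; the real value comes from quantizer.eps

stored_scale = (input_high - eps).reshape(scale_shape)
# On the forward call the quantizer effectively works with stored_scale + eps,
# which reproduces the original input_high exactly:
assert torch.allclose(stored_scale + eps, input_high.reshape(scale_shape))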
diff --git a/nncf/quantization/algorithms/smooth_quant/torch_backend.py b/nncf/quantization/algorithms/smooth_quant/torch_backend.py
index 19154231dbc..5db48a24f85 100644
--- a/nncf/quantization/algorithms/smooth_quant/torch_backend.py
+++ b/nncf/quantization/algorithms/smooth_quant/torch_backend.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Callable, List, Tuple
+from typing import Any, Callable, Dict, List, Tuple

 import numpy as np
 import torch
@@ -30,6 +30,9 @@
 from nncf.torch.graph.transformations.command_creation import create_command_to_update_weight
 from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
 from nncf.torch.graph.transformations.commands import PTTargetPoint
+from nncf.torch.layer_utils import COMPRESSION_MODULES
+from nncf.torch.layer_utils import CompressionParameter
+from nncf.torch.layer_utils import StatefullModuleInterface
 from nncf.torch.model_graph_manager import get_const_data
 from nncf.torch.model_graph_manager import get_const_node
 from nncf.torch.nncf_network import NNCFNetwork
@@ -38,14 +41,32 @@
 from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor


-class SQMultiply(torch.nn.Module):
-    def __init__(self, scale_value):
+@COMPRESSION_MODULES.register()
+class SQMultiply(torch.nn.Module, StatefullModuleInterface):
+    SCALE_SHAPE_KEY = "scale_shape"
+
+    def __init__(self, scale_shape: Tuple[int, ...]):
         super().__init__()
-        self._scale_value = scale_value
+        self._scale_value = CompressionParameter(torch.empty(scale_shape))
+
+    @property
+    def scale(self) -> torch.nn.Parameter:
+        return self._scale_value

-    def forward(self, x):
+    @scale.setter
+    def scale(self, value: torch.Tensor):
+        self._scale_value.data = value
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return torch.mul(x, self._scale_value)

+    def get_config(self) -> Dict[str, Any]:
+        return {self.SCALE_SHAPE_KEY: list(self._scale_value.shape)}
+
+    @classmethod
+    def from_config(cls, state) -> "SQMultiply":
+        return SQMultiply(state[cls.SCALE_SHAPE_KEY])
+

 PT_PRE_LAYER_TARGET_TYPE = TargetType.OPERATOR_PRE_HOOK

@@ -122,7 +143,7 @@ def weight_update_command(node_with_weight: NNCFNode, weight_value: np.ndarray)
     @staticmethod
     def scale_insertion_command(
         source_node: NNCFNode,
-        scale_value: np.ndarray,
+        scale_value: torch.Tensor,
         source_output_port_id: int,
         nodes: List[NNCFNode],
         scale_node_name: str,
@@ -132,7 +153,9 @@ def scale_insertion_command(
         for node in nodes:
             target_points.append(PTTargetPoint(PT_PRE_LAYER_TARGET_TYPE, node.node_name, input_port_id=input_port_id))

-        return PTSharedFnInsertionCommand(target_points, SQMultiply(scale_value), scale_node_name)
+        sq_multiply = SQMultiply(scale_value.shape)
+        sq_multiply.scale = scale_value
+        return PTSharedFnInsertionCommand(target_points, sq_multiply, scale_node_name)

     @staticmethod
     def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
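For illustration, a minimal sketch of the new SQMultiply contract introduced above: get_config() carries only the JSON-friendly scale shape, while the scale values themselves travel through the regular state_dict.

import torch

sq = SQMultiply((1, 3, 1))
sq.scale = 0.5 * torch.ones((1, 3, 1))              # values live in the state_dict
restored = SQMultiply.from_config(sq.get_config())  # the config restores only the shape
restored.load_state_dict(sq.state_dict())           # values are restored separately
assert torch.equal(restored.scale, sq.scale)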
diff --git a/nncf/torch/graph/transformations/serialization.py b/nncf/torch/graph/transformations/serialization.py
new file mode 100644
index 00000000000..282c59453eb
--- /dev/null
+++ b/nncf/torch/graph/transformations/serialization.py
@@ -0,0 +1,121 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+from typing import Any, Dict, Union
+
+from nncf.common.graph.transformations.commands import TransformationPriority
+from nncf.common.graph.transformations.layout import TransformationLayout
+from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
+from nncf.torch.graph.transformations.commands import PTInsertionCommand
+from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
+from nncf.torch.graph.transformations.commands import PTTargetPoint
+from nncf.torch.graph.transformations.commands import PTTransformationCommand
+from nncf.torch.layer_utils import COMPRESSION_MODULES
+
+COMPRESSION_STATE_ATTR = "compression_state"
+SUPPORTED_COMMANDS = (PTSharedFnInsertionCommand, PTInsertionCommand)
+
+
+def serialize_transformations(transformations_layout: TransformationLayout) -> Dict[str, Any]:
+    """
+    Serializes the given transformation layout to a dict.
+
+    :param transformations_layout: Given transformation layout.
+    :return: Serialized representation of the given transformation layout as a dict.
+    """
+    transformation_commands = []
+    for command in transformations_layout.transformations:
+        transformation_commands.append(serialize_command(command))
+
+    return {COMPRESSION_STATE_ATTR: transformation_commands}
+
+
+def deserialize_transformations(serialized_transformation_layout: Dict[str, Any]) -> TransformationLayout:
+    """
+    Deserializes the given serialized transformation layout.
+
+    :param serialized_transformation_layout: Given serialized transformation layout.
+    :return: The deserialized transformation layout.
+    """
+    transformation_layout = TransformationLayout()
+    for serialized_command in serialized_transformation_layout[COMPRESSION_STATE_ATTR]:
+        command = deserialize_command(serialized_command)
+        transformation_layout.register(command)
+
+    return transformation_layout
+
+
+def serialize_command(command: PTTransformationCommand) -> Dict[str, Any]:
+    """
+    Serializes the given command to a dict.
+
+    :param command: Given command.
+    :return: Serialized representation of the given command as a dict.
+    """
+    if not isinstance(command, SUPPORTED_COMMANDS):
+        raise RuntimeError(f"Command type {command.__class__} is not supported.")
+
+    serialized_transformation = dict()
+    serialized_transformation["type"] = command.__class__.__name__
+    if isinstance(command, PTSharedFnInsertionCommand):
+        serialized_transformation["target_points"] = [point.get_state() for point in command.target_points]
+        serialized_transformation["op_name"] = command.op_name
+        serialized_transformation["compression_module_type"] = command.compression_module_type.value
+    elif isinstance(command, PTInsertionCommand):
+        serialized_transformation["target_point"] = command.target_point.get_state()
+
+    # Check the compression module is registered
+    compression_module_name = command.fn.__class__.__name__
+    if compression_module_name not in COMPRESSION_MODULES.registry_dict:
+        raise RuntimeError(
+            f"Could not serialize compression module with name {compression_module_name}."
+            " Please register your module in the COMPRESSION_MODULES registry."
+        )
+    serialized_transformation["compression_module_name"] = compression_module_name
+    serialized_transformation["fn_config"] = command.fn.get_config()
+    serialized_transformation["hooks_group_name"] = command.hooks_group_name
+    priority = command.priority
+    serialized_transformation["priority"] = priority.value if isinstance(priority, Enum) else priority
+    return serialized_transformation
+
+
+def deserialize_command(serialized_command: Dict[str, Any]) -> Union[PTInsertionCommand, PTSharedFnInsertionCommand]:
+    """
+    Deserializes the given serialized command.
+
+    :param serialized_command: Given serialized command.
+    :return: The deserialized command.
+    """
+    if serialized_command["type"] not in (command_cls.__name__ for command_cls in SUPPORTED_COMMANDS):
+        raise RuntimeError(f"Command type {serialized_command['type']} is not supported.")
+
+    module_cls = COMPRESSION_MODULES.get(serialized_command["compression_module_name"])
+    fn = module_cls.from_config(serialized_command["fn_config"])
+    priority = serialized_command["priority"]
+    if priority in iter(TransformationPriority):
+        priority = TransformationPriority(priority)
+
+    if serialized_command["type"] == PTInsertionCommand.__name__:
+        target_point = PTTargetPoint.from_state(serialized_command["target_point"])
+        return PTInsertionCommand(
+            point=target_point, fn=fn, priority=priority, hooks_group_name=serialized_command["hooks_group_name"]
+        )
+
+    target_points = [PTTargetPoint.from_state(state) for state in serialized_command["target_points"]]
+    return PTSharedFnInsertionCommand(
+        target_points=target_points,
+        fn=fn,
+        op_unique_name=serialized_command["op_name"],
+        compression_module_type=ExtraCompressionModuleType(serialized_command["compression_module_type"]),
+        priority=priority,
+        hooks_group_name=serialized_command["hooks_group_name"],
+    )
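For illustration, the intended round trip through this new module, mirroring the load_from_config_impl helper in the tests further down; layout, model and example_input are assumed to exist:

import json

from nncf.common.factory import ModelTransformerFactory
from nncf.torch import wrap_model
from nncf.torch.graph.transformations.serialization import deserialize_transformations
from nncf.torch.graph.transformations.serialization import serialize_transformations

# Round-tripping through json proves the serialized state is JSON-compatible
state = json.loads(json.dumps(serialize_transformations(layout)))
nncf_network = wrap_model(model, example_input=example_input, trace_parameters=True)
restored = ModelTransformerFactory.create(nncf_network).transform(deserialize_transformations(state))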
+ """ + + class ProxyModule: def __init__(self, module): self._module = module @@ -117,7 +149,12 @@ def __init__(self, data: torch.Tensor = None, requires_grad: bool = True, compre """ super().__init__() + self._compression_lr_multiplier = compression_lr_multiplier if compression_lr_multiplier is not None and self.dtype.is_floating_point: self.requires_grad = True self.register_hook(lambda grad: compression_lr_multiplier * grad) self.requires_grad = requires_grad + + @property + def compression_lr_multiplier(self): + return self._compression_lr_multiplier diff --git a/nncf/torch/pruning/filter_pruning/layers.py b/nncf/torch/pruning/filter_pruning/layers.py index 2d95a9d982f..644a5eec7f5 100644 --- a/nncf/torch/pruning/filter_pruning/layers.py +++ b/nncf/torch/pruning/filter_pruning/layers.py @@ -8,6 +8,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from typing import Any, Dict + import numpy as np import torch from torch import nn @@ -15,15 +18,20 @@ import nncf from nncf.common.graph import NNCFNodeName from nncf.torch.layer_utils import COMPRESSION_MODULES +from nncf.torch.layer_utils import StatefullModuleInterface @COMPRESSION_MODULES.register() -class FilterPruningMask(nn.Module): +class FilterPruningMask(nn.Module, StatefullModuleInterface): """ A module contains the mask for pruning. On forward pass applying the mask to weight and bias of the module. """ + MASK_APPLYING_DIM_KEY = "dim" + NODE_NAME_KEY = "node_name" + SIZE_KEY = "size_key" + def __init__(self, size, node_name, dim=0): super().__init__() self.register_buffer("_binary_filter_pruning_mask", torch.ones(size)) @@ -31,11 +39,11 @@ def __init__(self, size, node_name, dim=0): self.node_name = node_name @property - def binary_filter_pruning_mask(self): + def binary_filter_pruning_mask(self) -> torch.Tensor: return self._binary_filter_pruning_mask @binary_filter_pruning_mask.setter - def binary_filter_pruning_mask(self, mask): + def binary_filter_pruning_mask(self, mask: torch.Tensor): with torch.no_grad(): self._binary_filter_pruning_mask.set_(mask) @@ -56,6 +64,19 @@ def forward(self, **params): ) return new_params + def get_config(self) -> Dict[str, Any]: + return { + self.MASK_APPLYING_DIM_KEY: self.mask_applying_dim, + self.NODE_NAME_KEY: self.node_name, + self.SIZE_KEY: list(self.binary_filter_pruning_mask.size()), + } + + @classmethod + def from_config(cls, state: Dict[str, Any]) -> "FilterPruningMask": + return FilterPruningMask( + size=state[cls.SIZE_KEY], node_name=state[cls.NODE_NAME_KEY], dim=state[cls.MASK_APPLYING_DIM_KEY] + ) + def broadcast_filter_mask(filter_mask, shape, dim=0): broadcasted_shape = np.ones(len(shape), dtype=np.int64) diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py index cb8906b1ff0..4b463600bc5 100644 --- a/nncf/torch/quantization/layers.py +++ b/nncf/torch/quantization/layers.py @@ -41,6 +41,7 @@ from nncf.torch.graph.transformations.commands import TargetType from nncf.torch.layer_utils import COMPRESSION_MODULES from nncf.torch.layer_utils import CompressionParameter +from nncf.torch.layer_utils import StatefullModuleInterface from nncf.torch.quantization.quantize_functions import ExportQuantizeToFakeQuantize from nncf.torch.quantization.quantize_functions import ExportQuantizeToONNXQuantDequant from nncf.torch.quantization.quantize_functions import TuneRange @@ -283,9 +284,10 @@ def add_quantization_point(self, qp_id: 
diff --git a/nncf/torch/pruning/filter_pruning/layers.py b/nncf/torch/pruning/filter_pruning/layers.py
index 2d95a9d982f..644a5eec7f5 100644
--- a/nncf/torch/pruning/filter_pruning/layers.py
+++ b/nncf/torch/pruning/filter_pruning/layers.py
@@ -8,6 +8,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from typing import Any, Dict
+
 import numpy as np
 import torch
 from torch import nn
@@ -15,15 +18,20 @@
 import nncf
 from nncf.common.graph import NNCFNodeName
 from nncf.torch.layer_utils import COMPRESSION_MODULES
+from nncf.torch.layer_utils import StatefullModuleInterface


 @COMPRESSION_MODULES.register()
-class FilterPruningMask(nn.Module):
+class FilterPruningMask(nn.Module, StatefullModuleInterface):
     """
     A module contains the mask for pruning.
     On forward pass applying the mask to weight and bias of the module.
     """

+    MASK_APPLYING_DIM_KEY = "dim"
+    NODE_NAME_KEY = "node_name"
+    SIZE_KEY = "size_key"
+
     def __init__(self, size, node_name, dim=0):
         super().__init__()
         self.register_buffer("_binary_filter_pruning_mask", torch.ones(size))
@@ -31,11 +39,11 @@ def __init__(self, size, node_name, dim=0):
         self.node_name = node_name

     @property
-    def binary_filter_pruning_mask(self):
+    def binary_filter_pruning_mask(self) -> torch.Tensor:
         return self._binary_filter_pruning_mask

     @binary_filter_pruning_mask.setter
-    def binary_filter_pruning_mask(self, mask):
+    def binary_filter_pruning_mask(self, mask: torch.Tensor):
         with torch.no_grad():
             self._binary_filter_pruning_mask.set_(mask)

@@ -56,6 +64,19 @@ def forward(self, **params):
         )
         return new_params

+    def get_config(self) -> Dict[str, Any]:
+        return {
+            self.MASK_APPLYING_DIM_KEY: self.mask_applying_dim,
+            self.NODE_NAME_KEY: self.node_name,
+            self.SIZE_KEY: list(self.binary_filter_pruning_mask.size()),
+        }
+
+    @classmethod
+    def from_config(cls, state: Dict[str, Any]) -> "FilterPruningMask":
+        return FilterPruningMask(
+            size=state[cls.SIZE_KEY], node_name=state[cls.NODE_NAME_KEY], dim=state[cls.MASK_APPLYING_DIM_KEY]
+        )
+

 def broadcast_filter_mask(filter_mask, shape, dim=0):
     broadcasted_shape = np.ones(len(shape), dtype=np.int64)
diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py
index cb8906b1ff0..4b463600bc5 100644
--- a/nncf/torch/quantization/layers.py
+++ b/nncf/torch/quantization/layers.py
@@ -41,6 +41,7 @@
 from nncf.torch.graph.transformations.commands import TargetType
 from nncf.torch.layer_utils import COMPRESSION_MODULES
 from nncf.torch.layer_utils import CompressionParameter
+from nncf.torch.layer_utils import StatefullModuleInterface
 from nncf.torch.quantization.quantize_functions import ExportQuantizeToFakeQuantize
 from nncf.torch.quantization.quantize_functions import ExportQuantizeToONNXQuantDequant
 from nncf.torch.quantization.quantize_functions import TuneRange
@@ -283,9 +284,10 @@ def add_quantization_point(self, qp_id: QuantizationPointId, qp: PTQuantizationP
         self.quantization_points[qp_id] = qp


-class BaseQuantizer(nn.Module, ABC):
+class BaseQuantizer(nn.Module, StatefullModuleInterface, ABC):
     def __init__(self, qspec: PTQuantizerSpec):
         super().__init__()
+        self._qspec = qspec
         self._narrow_range = qspec.narrow_range
         self._signedness_to_force = qspec.signedness_to_force
         self._is_using_log_scale_storage = qspec.logarithm_scale
@@ -563,6 +565,14 @@ def get_parameters_for_torch_fq(self) -> Tuple[int, int, torch.Tensor, torch.Ten
             zero_point - Quantizer zero point.
         """

+    def get_config(self):
+        return self._qspec.get_state()
+
+    @classmethod
+    def from_config(cls, state) -> "BaseQuantizer":
+        qsetup = PTQuantizerSpec.from_state(state)
+        return cls(qsetup)
+

 class QuantizersSwitcher:
     """Enables/disables quantizers with saving and restoring original state"""
diff --git a/nncf/torch/sparsity/layers.py b/nncf/torch/sparsity/layers.py
index 5a506b87a10..bf3794cd716 100644
--- a/nncf/torch/sparsity/layers.py
+++ b/nncf/torch/sparsity/layers.py
@@ -8,29 +8,32 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List
+from typing import Any, Dict, List

 import torch
 from torch import nn

 from nncf.torch.layer_utils import COMPRESSION_MODULES
+from nncf.torch.layer_utils import StatefullModuleInterface
 from nncf.torch.sparsity.functions import apply_binary_mask as apply_binary_mask_impl
 from nncf.torch.utils import is_tracing_state


 @COMPRESSION_MODULES.register()
-class BinaryMask(nn.Module):
+class BinaryMask(nn.Module, StatefullModuleInterface):
+    SHAPE_KEY = "shape"
+
     def __init__(self, shape: List[int]):
         super().__init__()
         self.register_buffer("_binary_mask", torch.ones(shape))
         self.frozen = False

     @property
-    def binary_mask(self):
+    def binary_mask(self) -> torch.Tensor:
         return self._binary_mask

     @binary_mask.setter
-    def binary_mask(self, tensor):
+    def binary_mask(self, tensor: torch.Tensor):
         with torch.no_grad():
             self._binary_mask.set_(tensor)

@@ -45,3 +48,10 @@ def _calc_training_binary_mask(self, weight):

     def apply_binary_mask(self, weight):
         return apply_binary_mask_impl(self.binary_mask, weight)
+
+    def get_config(self) -> Dict[str, Any]:
+        return {self.SHAPE_KEY: list(self.binary_mask.shape)}
+
+    @classmethod
+    def from_config(cls, state: Dict[str, Any]) -> "BinaryMask":
+        return BinaryMask(state[cls.SHAPE_KEY])
diff --git a/nncf/torch/sparsity/rb/layers.py b/nncf/torch/sparsity/rb/layers.py
index c1df48ad563..8d0199046d9 100644
--- a/nncf/torch/sparsity/rb/layers.py
+++ b/nncf/torch/sparsity/rb/layers.py
@@ -8,20 +8,26 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List
+from typing import Any, Dict, List

 import torch

 from nncf.torch.functions import logit
 from nncf.torch.layer_utils import COMPRESSION_MODULES
 from nncf.torch.layer_utils import CompressionParameter
+from nncf.torch.layer_utils import StatefullModuleInterface
 from nncf.torch.sparsity.layers import BinaryMask
 from nncf.torch.sparsity.rb.functions import binary_mask
 from nncf.torch.sparsity.rb.functions import calc_rb_binary_mask


 @COMPRESSION_MODULES.register()
-class RBSparsifyingWeight(BinaryMask):
+class RBSparsifyingWeight(BinaryMask, StatefullModuleInterface):
+    WEIGHTS_SHAPE_KEY = "weight_shape"
+    FROZEN_KEY = "frozen"
+    COMPRESSION_LR_MULTIPLIER_KEY = "compression_lr_multiplier"
+    EPS_KEY = "eps"
+
     def __init__(self, weight_shape: List[int], frozen=True, compression_lr_multiplier=None, eps=1e-6):
         super().__init__(weight_shape)
         self.frozen = frozen
@@ -36,11 +42,11 @@ def __init__(self, weight_shape: List[int], frozen=True, compression_lr_multipli
         self.mask_calculation_hook = MaskCalculationHook(self)

     @property
-    def mask(self):
+    def mask(self) -> torch.nn.Parameter:
         return self._mask

     @mask.setter
-    def mask(self, tensor):
+    def mask(self, tensor: torch.Tensor):
         self._mask.data = tensor
         self.binary_mask = binary_mask(self._mask)

@@ -51,6 +57,23 @@ def _calc_training_binary_mask(self, weight):
     def loss(self):
         return binary_mask(self._mask)

+    def get_config(self) -> Dict[str, Any]:
+        return {
+            self.WEIGHTS_SHAPE_KEY: list(self.mask.shape),
+            self.FROZEN_KEY: self.frozen,
+            self.COMPRESSION_LR_MULTIPLIER_KEY: self.mask.compression_lr_multiplier,
+            self.EPS_KEY: self.eps,
+        }
+
+    @classmethod
+    def from_config(cls, state: Dict[str, Any]) -> "RBSparsifyingWeight":
+        return RBSparsifyingWeight(
+            weight_shape=state[cls.WEIGHTS_SHAPE_KEY],
+            frozen=state[cls.FROZEN_KEY],
+            compression_lr_multiplier=state[cls.COMPRESSION_LR_MULTIPLIER_KEY],
+            eps=state[cls.EPS_KEY],
+        )
+

 class MaskCalculationHook:
     def __init__(self, module):
diff --git a/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_disable.json b/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_disable.json
index c1dc5e6731a..d462913193c 100644
--- a/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_disable.json
+++ b/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_disable.json
@@ -1,11 +1,19 @@
 {
     "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[/nncf_model_input_0|OUTPUT]": {
-        "input_low": 0.0,
-        "input_high": 0.9970665574073792
+        "input_low": [
+            0.0
+        ],
+        "input_high": [
+            0.9970665574073792
+        ]
     },
     "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0|OUTPUT]": {
-        "input_low": -3.8243322372436523,
-        "input_high": 3.794454574584961
+        "input_low": [
+            -3.8243322372436523
+        ],
+        "input_high": [
+            3.794454574584961
+        ]
     },
     "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0|INPUT1]": {
         "input_low": [
diff --git a/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_enable.json b/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_enable.json
index c679824f317..6f60ba19e5b 100644
--- a/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_enable.json
+++ b/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_enable.json
@@ -1,11 +1,19 @@
{ "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[/nncf_model_input_0|OUTPUT]": { - "input_low": 0.0, - "input_high": 0.9970665574073792 + "input_low": [ + 0.0 + ], + "input_high": [ + 0.9970665574073792 + ] }, "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0|OUTPUT]": { - "input_low": -3.8243322372436523, - "input_high": 3.794454574584961 + "input_low": [ + -3.8243322372436523 + ], + "input_high": [ + 3.794454574584961 + ] }, "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0|INPUT1]": { "input_low": [ diff --git a/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_first_layer_only.json b/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_first_layer_only.json index 708715926b7..89d8d054e81 100644 --- a/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_first_layer_only.json +++ b/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_first_layer_only.json @@ -1,11 +1,19 @@ { "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[/nncf_model_input_0|OUTPUT]": { - "input_low": 0.0, - "input_high": 0.9970665574073792 + "input_low": [ + 0.0 + ], + "input_high": [ + 0.9970665574073792 + ] }, "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0|OUTPUT]": { - "input_low": -3.8243322372436523, - "input_high": 3.794454574584961 + "input_low": [ + -3.8243322372436523 + ], + "input_high": [ + 3.794454574584961 + ] }, "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0|INPUT1]": { "input_low": [ diff --git a/tests/torch/helpers.py b/tests/torch/helpers.py index bf7e33cae13..aa8268ef1d1 100644 --- a/tests/torch/helpers.py +++ b/tests/torch/helpers.py @@ -43,6 +43,7 @@ from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.initialization import PTInitializingDataLoader from nncf.torch.initialization import register_default_init_args +from nncf.torch.layer_utils import StatefullModuleInterface from nncf.torch.layers import NNCF_MODULES_MAP from nncf.torch.model_creation import create_compressed_model from nncf.torch.module_operations import UpdateWeight @@ -215,24 +216,24 @@ def nz_bias_num(self): class TwoSharedConvTestModel(nn.Module): INPUT_SHAPE = [1, 1, 4, 4] NNCF_CONV_NODES_NAMES = [ - "TwoSharedConvTestModel/NNCFConv2d[conv1]/conv2d_0", - "TwoSharedConvTestModel/NNCFConv2d[conv2]/conv2d_0", + "TwoSharedConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0", + "TwoSharedConvTestModel/Sequential[features]/Sequential[1]/NNCFConv2d[0]/conv2d_0", ] CONV_NODES_NAMES = [ - "TwoSharedConvTestModel/Conv2d[conv1]/conv2d_0", - "TwoSharedConvTestModel/Conv2d[conv2]/conv2d_0", + "TwoSharedConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0", + "TwoSharedConvTestModel/Sequential[features]/Sequential[1]/Conv2d[0]/conv2d_0", ] def __init__(self): super().__init__() self.features = [] - self.conv1 = create_conv(1, 1, 1, -1, -2) - self.conv2 = create_conv(1, 1, 1, 0, 0) + self.features.append(nn.Sequential(create_conv(1, 1, 1, -1, -2))) + 
+        self.features.append(nn.Sequential(create_conv(1, 1, 1, 0, 0)))
+        self.features = nn.Sequential(*self.features)

     def forward(self, x):
         for _ in range(2):
-            x = self.conv1(x)
-            x = self.conv2(x)
+            x = self.features(x)
         return x

@@ -265,24 +266,31 @@ def num_flat_features(self, x):
         return num_features


-class DummyOpWithState(torch.nn.Module):
+class DummyOpWithState(torch.nn.Module, StatefullModuleInterface):
     def __init__(self, state: str):
         super().__init__()
         self._state = state
+        # Keep a dummy param to check the state dict
+        self._dummy_param = torch.nn.Parameter(
+            torch.tensor(
+                0.0,
+            )
+        )

-    def __call__(self, *args):
+    def forward(self, *args):
         if len(args) == 1:
-            return args[0]
+            return args[0] + self._dummy_param
         # To work correctly with
         # TargetType.PRE_LAYER_OPERATION
         # TargetType.POST_LAYER_OPERATION
+        args[0].weight + self._dummy_param
         return None

-    def get_state(self):
+    def get_config(self):
         return self._state

     @classmethod
-    def from_state(cls, state: str):
+    def from_config(cls, state: str):
         return cls(state)
diff --git a/tests/torch/nncf_network/helpers.py b/tests/torch/nncf_network/helpers.py
index 06805cd59b7..719166ffa9f 100644
--- a/tests/torch/nncf_network/helpers.py
+++ b/tests/torch/nncf_network/helpers.py
@@ -54,7 +54,9 @@ class InsertionCommandBuilder:

     Contains methods which allow to build all possible commands for the given torch.nn.Module.
     Target module should have NNCF_CONV_NODES_NAMES and CONV_NODES_NAMES with names of
-    target model convolutions and names of nncf-wrapped target model convolutions
+    target model convolutions and names of nncf-wrapped target model convolutions.
+    Convolutions should be placed inside an nn.Sequential in the .features attribute
+    for test compatibility.
     """

     AVAILABLE_MODELS = (TwoConvTestModel, TwoSharedConvTestModel)
@@ -162,7 +164,7 @@ def get_all_available_commands(
                     command_type, target_type
                 ):
                     continue
-                command = self._create_command(
+                command = self.create_one_command(
                     command_builder,
                     target_type,
                     priority,
@@ -185,7 +187,7 @@ def is_unsupported_by_transformer_command(command_type: PTTransformationCommand,
     ]

     @staticmethod
-    def _create_command(
+    def create_one_command(
         command_builder,
         target_type,
         priority,
diff --git a/tests/torch/nncf_network/test_get_applied_modifications.py b/tests/torch/nncf_network/test_transformation_layout.py
similarity index 100%
rename from tests/torch/nncf_network/test_get_applied_modifications.py
rename to tests/torch/nncf_network/test_transformation_layout.py
diff --git a/tests/torch/ptq/helpers.py b/tests/torch/ptq/helpers.py
index 7dd88540104..4047f892e21 100644
--- a/tests/torch/ptq/helpers.py
+++ b/tests/torch/ptq/helpers.py
@@ -20,7 +20,6 @@
 from nncf.torch.graph.operator_metatypes import PTModuleDepthwiseConv2dSubtype
 from nncf.torch.graph.operator_metatypes import PTModuleLinearMetatype
 from nncf.torch.graph.operator_metatypes import PTSumMetatype
-from nncf.torch.tensor_statistics.statistics import PTMinMaxTensorStatistic
 from tests.post_training.test_templates.models import NNCFGraphToTest
 from tests.post_training.test_templates.models import NNCFGraphToTestDepthwiseConv
 from tests.post_training.test_templates.models import NNCFGraphToTestSumAggregation
@@ -81,15 +80,3 @@ def get_nncf_network(model: torch.nn.Module, input_shape: Optional[List[int]] =
     model = model.eval()
     device = next(model.named_parameters())[1].device
     return wrap_model(model, torch.ones(input_shape).to(device=device), trace_parameters=True)
-
-
-def mock_collect_statistics(mocker):
-    _ = mocker.patch(
-        "nncf.common.tensor_statistics.aggregator.StatisticsAggregator.collect_statistics", return_value=None
-    )
-    min_, max_ = 0.0, 1.0
-    min_, max_ = torch.tensor(min_), torch.tensor(max_)
-    _ = mocker.patch(
-        "nncf.experimental.common.tensor_statistics.collectors.TensorCollector.get_statistics",
-        return_value=PTMinMaxTensorStatistic(min_values=min_, max_values=max_),
-    )
"nncf.common.tensor_statistics.aggregator.StatisticsAggregator.collect_statistics", return_value=None - ) - min_, max_ = 0.0, 1.0 - min_, max_ = torch.tensor(min_), torch.tensor(max_) - _ = mocker.patch( - "nncf.experimental.common.tensor_statistics.collectors.TensorCollector.get_statistics", - return_value=PTMinMaxTensorStatistic(min_values=min_, max_values=max_), - ) diff --git a/tests/torch/ptq/test_graphs.py b/tests/torch/ptq/test_graphs.py index ee59703d3ed..93281435104 100644 --- a/tests/torch/ptq/test_graphs.py +++ b/tests/torch/ptq/test_graphs.py @@ -24,7 +24,6 @@ from tests.post_training.test_templates.helpers import EmbeddingModel from tests.post_training.test_templates.helpers import get_static_dataset from tests.torch import test_models -from tests.torch.ptq.helpers import mock_collect_statistics from tests.torch.quantization.test_algo_quantization import SharedLayersModel from tests.torch.test_compressed_graph import ModelDesc from tests.torch.test_compressed_graph import check_graph @@ -95,15 +94,20 @@ def get_model_name(description): ("desc", "quantization_parameters"), TEST_MODELS_DESC, ids=[get_model_name(m) for m in TEST_MODELS_DESC] ) def test_min_max_classification_quantized_graphs(desc: ModelDesc, quantization_parameters, graph_dir, mocker): - mock_collect_statistics(mocker) model = desc.model_builder() nncf_network = wrap_model(model, torch.ones(desc.input_sample_sizes), trace_parameters=True) quantization_parameters["advanced_parameters"] = AdvancedQuantizationParameters(disable_bias_correction=True) + quantization_parameters["subset_size"] = 1 quantization_algorithm = PostTrainingQuantization(**quantization_parameters) + def transform_fn(input_) -> torch.Tensor: + return torch.tensor(input_[0]) + quantized_model = quantization_algorithm.apply( - nncf_network, nncf_network.nncf.get_graph(), dataset=get_static_dataset(desc.input_sample_sizes, None, None) + nncf_network, + nncf_network.nncf.get_graph(), + dataset=get_static_dataset(desc.input_sample_sizes, transform_fn, None), ) check_graph(quantized_model.nncf.get_graph(), desc.dot_filename(), graph_dir) diff --git a/tests/torch/test_serialization.py b/tests/torch/test_serialization.py new file mode 100644 index 00000000000..b6830383e9b --- /dev/null +++ b/tests/torch/test_serialization.py @@ -0,0 +1,332 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import json
+from copy import deepcopy
+
+import pytest
+import torch
+
+from nncf.common.factory import ModelTransformerFactory
+from nncf.common.quantization.structs import QuantizationScheme
+from nncf.quantization.algorithms.smooth_quant.torch_backend import SQMultiply
+from nncf.torch import wrap_model
+from nncf.torch.graph.transformations.commands import PTTransformationCommand
+from nncf.torch.graph.transformations.commands import TransformationType
+from nncf.torch.graph.transformations.serialization import deserialize_command
+from nncf.torch.graph.transformations.serialization import deserialize_transformations
+from nncf.torch.graph.transformations.serialization import serialize_command
+from nncf.torch.graph.transformations.serialization import serialize_transformations
+from nncf.torch.module_operations import UpdateWeight
+from nncf.torch.nncf_network import NNCFNetwork
+from nncf.torch.pruning.filter_pruning.layers import FilterPruningMask
+from nncf.torch.quantization.layers import AsymmetricQuantizer
+from nncf.torch.quantization.layers import BaseQuantizer
+from nncf.torch.quantization.layers import PTQuantizerSpec
+from nncf.torch.quantization.layers import SymmetricQuantizer
+from nncf.torch.sparsity.layers import BinaryMask
+from nncf.torch.sparsity.rb.layers import RBSparsifyingWeight
+from tests.torch.helpers import DummyOpWithState
+from tests.torch.helpers import TwoConvTestModel
+from tests.torch.helpers import commands_are_equal
+from tests.torch.nncf_network.helpers import AVAILABLE_TARGET_TYPES
+from tests.torch.nncf_network.helpers import InsertionCommandBuilder
+
+
+def load_from_config_impl(model: torch.nn.Module, serialized_transformations, example_input, trace_parameters):
+    """
+    Test implementation of nncf.torch.load_from_config(). Should be replaced by the actual
+    implementation once it is available.
+    """
+    transformations_layout = deserialize_transformations(serialized_transformations)
+
+    nncf_network = wrap_model(deepcopy(model), example_input=example_input, trace_parameters=trace_parameters)
+    transformed_model = ModelTransformerFactory.create(nncf_network).transform(transformations_layout)
+
+    transformed_model.nncf.disable_dynamic_graph_building()
+    return transformed_model
+
+
+def nncf_get_config_impl(
+    model: NNCFNetwork,
+):
+    """
+    Test implementation of model.nncf.get_config().
+    Should be replaced by the actual implementation once it is available.
+    """
+    layout = model.nncf.transformation_layout()
+    return serialize_transformations(layout)
+
+
+@pytest.mark.parametrize("target_type", AVAILABLE_TARGET_TYPES)
+@pytest.mark.parametrize("command_builder", InsertionCommandBuilder(TwoConvTestModel).get_command_builders())
+@pytest.mark.parametrize("priority", InsertionCommandBuilder.PRIORITIES)
+def test_serialize_load_command(target_type, command_builder, priority):
+    dummy_op_state = "DUMMY_OP_STATE"
+    op_unique_name = "UNIQUE_NAME"
+    # The only difference the trace_parameters param makes in this test is the target node names
+    command = InsertionCommandBuilder(TwoConvTestModel).create_one_command(
+        command_builder[0], target_type, priority, dummy_op_state, trace_parameters=False, op_unique_name=op_unique_name
+    )
+
+    serialized_command = serialize_command(command)
+
+    # Check the serialized transformations are JSON-compatible
+    j_str = json.dumps(serialized_command)
+    serialized_command = json.loads(j_str)
+
+    recovered_command = deserialize_command(serialized_command)
+    _check_commands_after_serialization(command, recovered_command, dummy_op_state)
+
+
+def test_non_supported_command_serialization():
+    class NonSupportedCommand(PTTransformationCommand):
+        def __init__(self):
+            super().__init__(TransformationType.INSERT, None)
+
+    command = NonSupportedCommand()
+
+    with pytest.raises(RuntimeError):
+        serialize_command(command)
+
+    serialized_command = {"type": NonSupportedCommand.__name__}
+    with pytest.raises(RuntimeError):
+        deserialize_command(serialized_command)
+
+
+def test_serialize_transformations():
+    dummy_op_state = "DUMMY_OP_STATE"
+    # The only difference the trace_parameters param makes in this test is the target node names
+    layout = InsertionCommandBuilder(TwoConvTestModel).get_all_available_commands(
+        dummy_op_state=dummy_op_state, trace_parameters=False
+    )
+
+    serialized_transformations = serialize_transformations(layout)
+
+    # Check the serialized transformations are JSON-compatible
+    j_str = json.dumps(serialized_transformations)
+    serialized_transformations = json.loads(j_str)
+
+    recovered_layout = deserialize_transformations(serialized_transformations)
+    assert len(layout.transformations) == len(recovered_layout.transformations)
+    # Can zip layouts because the order should not be altered
+    for command, recovered_command in zip(layout.transformations, recovered_layout.transformations):
+        _check_commands_after_serialization(command, recovered_command, dummy_op_state)
+
+
+@pytest.mark.parametrize("model_cls", InsertionCommandBuilder.AVAILABLE_MODELS)
+@pytest.mark.parametrize("trace_parameters", (False, True))
+def test_get_apply_serialization_from_a_model(model_cls, trace_parameters):
+    dummy_op_state = "DUMMY_OP_STATE"
+    layout = InsertionCommandBuilder(model_cls).get_all_available_commands(
+        dummy_op_state, trace_parameters, skip_model_transformer_unsupported=True
+    )
+    model = model_cls()
+    example_input = torch.ones((1, 1, 4, 4))
+    nncf_model = wrap_model(deepcopy(model), example_input=example_input, trace_parameters=trace_parameters)
+    modified_model = ModelTransformerFactory.create(nncf_model).transform(layout)
+
+    serialized_transformations = nncf_get_config_impl(modified_model)
+
+    # Check the serialized transformations are JSON-compatible
+    j_str = json.dumps(serialized_transformations)
+    serialized_transformations = json.loads(j_str)
+
+    recovered_model = load_from_config_impl(model, serialized_transformations, example_input, trace_parameters)
+
+    assert modified_model.state_dict().keys() == recovered_model.state_dict().keys()
+    if not trace_parameters:
+        _check_pre_post_ops(modified_model, recovered_model)
+
+    context = modified_model.nncf._compressed_context
+    recovered_context = recovered_model.nncf._compressed_context
+    for hooks_attr in ["_pre_hooks", "_post_hooks"]:
+        container = getattr(context, hooks_attr)
+        recovered_container = getattr(recovered_context, hooks_attr)
+        assert len(container) == len(recovered_container)
+        for op_address, hooks in container.items():
+            recovered_hooks = recovered_container[op_address]
+            for k, hook in hooks.items():
+                recovered_hook = recovered_hooks[k]
+                _check_hook_are_equal(hook, recovered_hook)
+
+    for attr_name in ["external_quantizers", "external_op"]:
+        container = getattr(modified_model.nncf, attr_name)
+        recovered_container = getattr(recovered_model.nncf, attr_name)
+        assert len(container) == len(recovered_container)
+        for k, module in container.items():
+            recovered_module = recovered_container[k]
+            _check_hook_are_equal(module, recovered_module)
+
+
+def _check_pre_post_ops(modified_model, recovered_model):
+    for conv, recovered_conv in zip(modified_model.features, recovered_model.features):
+        for hooks_attr in ["pre_ops", "post_ops"]:
+            hooks = getattr(conv[0], hooks_attr)
+            recovered_hooks = getattr(recovered_conv[0], hooks_attr)
+            assert len(hooks) == len(recovered_hooks)
+            for k, hook in hooks.items():
+                recovered_hook = recovered_hooks[k]
+                if isinstance(hook, UpdateWeight):
+                    assert isinstance(recovered_hook, UpdateWeight)
+                    hook = hook.op
+                    recovered_hook = recovered_hook.op
+                _check_hook_are_equal(hook, recovered_hook)
+
+
+def _check_hook_are_equal(hook, recovered_hook):
+    assert type(hook) is type(recovered_hook)
+    if isinstance(hook, DummyOpWithState):
+        assert hook.get_config() == recovered_hook.get_config()
+        return
+    # Otherwise, the hook is an external op call hook
+    assert hook._storage_name == recovered_hook._storage_name
+    assert hook._storage_key == recovered_hook._storage_key
+
+
+def _check_commands_after_serialization(command, recovered_command, dummy_op_state=None):
+    commands_are_equal(recovered_command, command, check_fn_ref=False)
+    assert isinstance(command.fn, DummyOpWithState)
+    assert command.fn.get_config() == recovered_command.fn.get_config()
+    if dummy_op_state is not None:
+        assert command.fn.get_config() == dummy_op_state
+
+
+@pytest.mark.parametrize("size", (4, [3, 4]))
+def test_pruning_mask_serialization(size):
+    node_name = "dummy_node_name"
+    dim = 2
+    mask = FilterPruningMask(size=size, node_name=node_name, dim=dim)
+    mask.binary_filter_pruning_mask = torch.fill(torch.empty(size), 5)
+    state_dict = mask.state_dict()
+
+    state = mask.get_config()
+    json_state = json.dumps(state)
+    state = json.loads(json_state)
+
+    recovered_mask = FilterPruningMask.from_config(state)
+    recovered_mask.load_state_dict(state_dict)
+
+    ref_size = size if isinstance(size, list) else [size]
+    assert list(recovered_mask.binary_filter_pruning_mask.size()) == ref_size
+    assert recovered_mask.node_name == node_name
+    assert recovered_mask.mask_applying_dim == dim
+
+    assert torch.all(mask.binary_filter_pruning_mask == recovered_mask.binary_filter_pruning_mask)
+
+
+@pytest.mark.parametrize("quantizer_class", (SymmetricQuantizer, AsymmetricQuantizer))
+def test_quantizer_serialization(quantizer_class: BaseQuantizer):
+    scale_shape = [1, 3, 1, 1]
+    ref_qspec = PTQuantizerSpec(
+        num_bits=4,
+        mode=QuantizationScheme.ASYMMETRIC,
+        signedness_to_force=False,
+        narrow_range=True,
+        half_range=False,
+        scale_shape=scale_shape,
+        logarithm_scale=False,
+        is_quantized_on_export=False,
+        compression_lr_multiplier=2.0,
+    )
+    quantizer = quantizer_class(ref_qspec)
+    if isinstance(quantizer, SymmetricQuantizer):
+        quantizer.scale = torch.nn.Parameter(torch.fill(torch.empty(scale_shape), 5))
+    elif isinstance(quantizer, AsymmetricQuantizer):
+        quantizer.input_low = torch.nn.Parameter(torch.fill(torch.empty(scale_shape), 6))
+        quantizer.input_range = torch.nn.Parameter(torch.fill(torch.empty(scale_shape), 7))
+
+    state_dict = quantizer.state_dict()
+
+    state = quantizer.get_config()
+    json_state = json.dumps(state)
+    state = json.loads(json_state)
+
+    recovered_quantizer = quantizer_class.from_config(state)
+    recovered_quantizer.load_state_dict(state_dict)
+
+    assert recovered_quantizer._qspec == ref_qspec
+
+    assert torch.all(quantizer._num_bits == recovered_quantizer._num_bits)
+    assert torch.all(quantizer.enabled == recovered_quantizer.enabled)
+    if isinstance(quantizer, SymmetricQuantizer):
+        assert torch.all(quantizer.signed_tensor == recovered_quantizer.signed_tensor)
+        assert torch.all(quantizer.scale == recovered_quantizer.scale)
+    elif isinstance(quantizer, AsymmetricQuantizer):
+        assert torch.all(quantizer.input_low == recovered_quantizer.input_low)
+        assert torch.all(quantizer.input_range == recovered_quantizer.input_range)
+    else:
+        raise RuntimeError(f"Unexpected quantizer class: {quantizer_class}")
+
+
+def test_sparsity_binary_mask_serialization():
+    ref_shape = [4, 2, 1, 3]
+    mask = BinaryMask(ref_shape)
+    mask.binary_mask = torch.zeros(ref_shape)
+    state_dict = mask.state_dict()
+
+    state = mask.get_config()
+    json_state = json.dumps(state)
+    state = json.loads(json_state)
+
+    recovered_mask = BinaryMask.from_config(state)
+    recovered_mask.load_state_dict(state_dict)
+
+    assert list(recovered_mask.binary_mask.shape) == ref_shape
+    assert torch.all(mask.binary_mask == recovered_mask.binary_mask)
+
+
+def test_rb_sparsity_mask_serialization():
+    ref_weights_shape = [3, 2, 4, 1]
+    ref_frozen = False
+    ref_compression_lr_multiplier = 2.0
+    ref_eps = 0.3
+    mask = RBSparsifyingWeight(
+        weight_shape=ref_weights_shape,
+        frozen=ref_frozen,
+        compression_lr_multiplier=ref_compression_lr_multiplier,
+        eps=ref_eps,
+    )
+    mask.binary_mask = torch.zeros(ref_weights_shape)
+    mask.mask = torch.fill(torch.empty(ref_weights_shape), 5)
+    state_dict = mask.state_dict()
+
+    state = mask.get_config()
+    json_state = json.dumps(state)
+    state = json.loads(json_state)
+
+    recovered_mask = RBSparsifyingWeight.from_config(state)
+    recovered_mask.load_state_dict(state_dict)
+
+    assert list(recovered_mask.mask.shape) == ref_weights_shape
+    assert recovered_mask.frozen == ref_frozen
+    assert recovered_mask.mask.compression_lr_multiplier == ref_compression_lr_multiplier
+    assert recovered_mask.eps == ref_eps
+
+    assert torch.all(mask.mask == recovered_mask.mask)
+    assert torch.all(mask.binary_mask == recovered_mask.binary_mask)
+    assert torch.all(mask.uniform == recovered_mask.uniform)
+
+
+def test_sq_multiply_serialization():
+    tensor_shape = [1, 3, 5]
+    tensor_value = torch.fill(torch.empty(tensor_shape, dtype=torch.float16), 5)
+    sq_multiply = SQMultiply(tensor_shape)
+    sq_multiply.scale = tensor_value
+    state_dict = sq_multiply.state_dict()
+
+    state = sq_multiply.get_config()
+    json_state = json.dumps(state)
+    state = json.loads(json_state)
+
+    recovered_sq_multiply = SQMultiply.from_config(state)
+    recovered_sq_multiply.load_state_dict(state_dict)
+
+    assert torch.all(sq_multiply.scale == recovered_sq_multiply.scale)
+    assert sq_multiply.scale.shape == recovered_sq_multiply.scale.shape