diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py
index fa1f2b0cbb6..d5a37ddcbe5 100644
--- a/nncf/quantization/algorithms/min_max/torch_backend.py
+++ b/nncf/quantization/algorithms/min_max/torch_backend.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 import torch
 
@@ -36,7 +36,9 @@
 from nncf.quantization.range_estimator import RangeEstimatorParameters
 from nncf.torch.graph.graph import PTNNCFGraph
 from nncf.torch.graph.graph import PTTargetPoint
-from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand
+from nncf.torch.graph.transformations.command_creation import create_quantizer_insertion_command
+from nncf.torch.graph.transformations.commands import PTInsertionCommand
+from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
 from nncf.torch.hardware.config import PTHWConfig
 from nncf.torch.nncf_network import NNCFNetwork
 from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
@@ -128,17 +130,6 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) -
         target_type = PTMinMaxAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP[target_type]
         return PTTargetPoint(target_type, target_node_name, input_port_id=port_id)
 
-    @staticmethod
-    def create_quantizer_insertion_command(
-        nncf_graph: NNCFGraph,
-        target_point: PTTargetPoint,
-        quantizer_config: QuantizerConfig,
-        parameters: FakeQuantizeParameters,
-    ) -> PTQuantizerInsertionCommand:
-        return PTMinMaxAlgoBackend._create_quantizer_insertion_command(
-            nncf_graph, target_point, quantizer_config, parameters
-        )
-
     @staticmethod
     def create_convert_insertion_command(
         target_point: PTTargetPoint,
@@ -290,12 +281,12 @@ def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantiz
         quantizer.scale = torch.nn.Parameter(parameters.input_high.data - quantizer.eps)
 
     @staticmethod
-    def _create_quantizer_insertion_command(
+    def create_quantizer_insertion_command(
         nncf_graph: NNCFGraph,
         target_point: PTTargetPoint,
         quantizer_config: QuantizerConfig,
         parameters: FakeQuantizeParameters,
-    ) -> PTQuantizerInsertionCommand:
+    ) -> Union[PTInsertionCommand, PTSharedFnInsertionCommand]:
         _, scale_shape, _ = PTMinMaxAlgoBackend._get_input_scale_shape(
             nncf_graph, target_point, quantizer_config.per_channel
         )
@@ -303,7 +294,7 @@ def _create_quantizer_insertion_command(
         quantizer = PTMinMaxAlgoBackend._create_quantizer(
             quantizer_config, scale_shape, parameters, target_point.target_type
         )
-        return PTQuantizerInsertionCommand(target_point, quantizer)
+        return create_quantizer_insertion_command(target_point, quantizer)
 
     @staticmethod
     def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[OperatorMetatype]:
diff --git a/nncf/torch/external_hook.py b/nncf/torch/external_hook.py
index 7983e74da7e..60902afbbe2 100644
--- a/nncf/torch/external_hook.py
+++ b/nncf/torch/external_hook.py
@@ -11,7 +11,7 @@
 
 from typing import Any
 
-from nncf.torch.dynamic_graph.context import TracingContext
+from nncf.torch.dynamic_graph.context import get_current_context
 
 EXTERNAL_OP_STORAGE_NAME = "external_op"
 
@@ -26,17 +26,15 @@ class ExternalOpCallHook:
     the base module execution.
     """
 
-    def __init__(self, storage_name: str, context: TracingContext, storage_key: str):
+    def __init__(self, storage_name: str, storage_key: str):
         """
         :param storage_name: Attribute name of a model NNCFInterface.
-        :param context: Current tracing context.
         :param storage_key: Key to retrieve callable hook
         """
         self._storage_name = storage_name
-        self._compressed_context = context
         self._storage_key = storage_key
 
     def __call__(self, *args: Any, **kwargs) -> Any:
-        replica = self._compressed_context.base_module_thread_local_replica
+        replica = get_current_context().base_module_thread_local_replica
         storage = getattr(replica.nncf, self._storage_name)
         return storage[self._storage_key](*args, **kwargs)
diff --git a/nncf/torch/graph/transformations/command_creation.py b/nncf/torch/graph/transformations/command_creation.py
index ac52802c039..16b14c3e172 100644
--- a/nncf/torch/graph/transformations/command_creation.py
+++ b/nncf/torch/graph/transformations/command_creation.py
@@ -9,13 +9,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Union
+
 from torch import Tensor
 
 from nncf.common.graph.graph import NNCFNode
 from nncf.common.graph.transformations.commands import TargetType
+from nncf.common.graph.transformations.commands import TransformationPriority
+from nncf.common.quantization.structs import NonWeightQuantizerId
+from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
 from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
+from nncf.torch.graph.transformations.commands import PTInsertionCommand
+from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
 from nncf.torch.graph.transformations.commands import PTTargetPoint
 from nncf.torch.graph.transformations.commands import PTWeightUpdateCommand
+from nncf.torch.quantization.layers import BaseQuantizer
 
 
 def create_bias_correction_command(node: NNCFNode, bias_value: Tensor) -> PTBiasCorrectionCommand:
@@ -40,3 +48,20 @@ def create_command_to_update_weight(node: NNCFNode, weight_value: Tensor) -> PTW
     """
     target_point = PTTargetPoint(TargetType.LAYER, node.node_name)
     return PTWeightUpdateCommand(target_point, weight_value)
+
+
+def create_quantizer_insertion_command(
+    target_point: PTTargetPoint, quantizer: BaseQuantizer
+) -> Union[PTInsertionCommand, PTSharedFnInsertionCommand]:
+    if target_point.type is TargetType.OPERATION_WITH_WEIGHTS:
+        return PTInsertionCommand(target_point, quantizer, TransformationPriority.QUANTIZATION_PRIORITY)
+
+    quantizer_id = NonWeightQuantizerId(target_point.target_node_name, target_point.input_port_id)
+    storage_key = str(quantizer_id)
+    return PTSharedFnInsertionCommand(
+        target_points=[target_point],
+        fn=quantizer,
+        op_unique_name=storage_key,
+        compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
+        priority=TransformationPriority.QUANTIZATION_PRIORITY,
+    )
diff --git a/nncf/torch/graph/transformations/commands.py b/nncf/torch/graph/transformations/commands.py
index c7793f27a28..4e27236f5e9 100644
--- a/nncf/torch/graph/transformations/commands.py
+++ b/nncf/torch/graph/transformations/commands.py
@@ -9,6 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from enum import Enum
 from typing import Any, Callable, Dict, List
 
 import torch
@@ -150,12 +151,18 @@ def requires_graph_rebuild(self):
         return self.priority == TransformationPriority.QUANTIZATION_PRIORITY
 
 
+class ExtraCompressionModuleType(Enum):
+    EXTERNAL_QUANTIZER = 0
+    EXTERNAL_OP = 1
+
+
 class PTSharedFnInsertionCommand(PTTransformationCommand):
     def __init__(
         self,
         target_points: List[PTTargetPoint],
         fn: Callable,
         op_unique_name: str,
+        compression_module_type: ExtraCompressionModuleType = ExtraCompressionModuleType.EXTERNAL_OP,
         priority: TransformationPriority = TransformationPriority.DEFAULT_PRIORITY,
         hooks_group_name: str = DEFAULT_HOOKS_GROUP_NAME,
     ):
@@ -163,6 +170,7 @@ def __init__(
         self.target_points = target_points
         self.fn = fn
         self.op_name = op_unique_name
+        self.compression_module_type = compression_module_type
         self.priority = priority
         self.hooks_group_name = hooks_group_name
 
@@ -170,25 +178,6 @@ def requires_graph_rebuild(self):
         return True
 
 
-class PTQuantizerInsertionCommand(PTTransformationCommand):
-    """
-    Insertion quantizer operation to the models.
-    """
-
-    def __init__(
-        self,
-        point: PTTargetPoint,
-        quantizer: "BaseQuantizer",  # noqa: F821
-        hooks_group_name: str = DEFAULT_HOOKS_GROUP_NAME,
-    ):
-        super().__init__(TransformationType.INSERT, point)
-        self.quantizer = quantizer
-        self.hooks_group_name = hooks_group_name
-
-    def requires_graph_rebuild(self):
-        return True
-
-
 class PTModelExtractionWithFusedBiasCommand(PTCommand):
     """
     Extracts sequence by name with node that contain fused bias.
diff --git a/nncf/torch/model_graph_manager.py b/nncf/torch/model_graph_manager.py
index 83a0f02b6d4..ad788520705 100644
--- a/nncf/torch/model_graph_manager.py
+++ b/nncf/torch/model_graph_manager.py
@@ -18,10 +18,11 @@
 from nncf.common.graph.graph import NNCFNode
 from nncf.common.graph.operator_metatypes import CONST_NOOP_METATYPES
 from nncf.torch.dynamic_graph.context import PreHookId
+from nncf.torch.external_hook import ExternalOpCallHook
 from nncf.torch.graph import operator_metatypes as om
 from nncf.torch.nncf_network import NNCFNetwork
-from nncf.torch.quantization.external_quantizer import ExternalQuantizerCallHook
 from nncf.torch.quantization.layers import AsymmetricQuantizer
+from nncf.torch.quantization.layers import BaseQuantizer
 from nncf.torch.quantization.layers import SymmetricQuantizer
 
 CONV_META_TYPES = [
@@ -295,7 +296,9 @@ def get_fake_quantizer(
     hook_container = model.nncf._compressed_context._post_hooks.get(op_addr, {})
 
     for call_hook in hook_container.values():
-        if isinstance(call_hook, ExternalQuantizerCallHook):
+        if isinstance(call_hook, ExternalOpCallHook):
             storage = getattr(model.nncf, call_hook._storage_name)
-            return storage[call_hook._storage_key]
+            module = storage[call_hook._storage_key]
+            if isinstance(module, BaseQuantizer):
+                return module
     return None
+ """ + if not model.nncf.is_compression_module_registered(compression_module_type): + model.nncf.register_compression_module_type(compression_module_type) insertion_commands: List[PTInsertionCommand] = [] for shared_command in transformations: - model.nncf.add_compression_module(shared_command.op_name, shared_command.fn, compression_model_type) + fn = shared_command.fn + if device is not None: + fn.to(device) + + model.nncf.add_compression_module(shared_command.op_name, fn, compression_module_type) for target_point in shared_command.target_points: fn = ExternalOpCallHook( - EXTERNAL_OP_STORAGE_NAME, model.nncf.get_tracing_context(), shared_command.op_name + compression_module_type_to_attr_name(compression_module_type), shared_command.op_name ) insertion_commands.append( PTInsertionCommand( @@ -138,47 +190,7 @@ def _apply_shared_nodes_insertion( ) ) - return PTModelTransformer._apply_insertion_transformations(model, insertion_commands) - - @staticmethod - def _apply_quantizer_insertion_transformations( - model: NNCFNetwork, transformations: List[PTQuantizerInsertionCommand] - ) -> NNCFNetwork: - """ - Applies quantizer insertion transformations on the model. - - :param model: Model to apply transformations. - :param transformations: List of the OVQuantizerInsertionCommand transformations. - :return: Model with inserted FakeQuantize nodes. - """ - compression_model_type = ExtraCompressionModuleType.EXTERNAL_QUANTIZER - - if not model.nncf.is_compression_module_registered(compression_model_type): - model.nncf.register_compression_module_type(compression_model_type) - - insertion_commands: List[PTInsertionCommand] = [] - device = None - if not is_multidevice(model): - device = get_model_device(model) - - for transformation_command in transformations: - target_point: PTTargetPoint = transformation_command.target_point - quantizer_module = transformation_command.quantizer - if device is not None: - quantizer_module = quantizer_module.to(device) - fn = quantizer_module - - if target_point.type is not TargetType.OPERATION_WITH_WEIGHTS: - quantizer_id = NonWeightQuantizerId(target_point.target_node_name, target_point.input_port_id) - storage_key = str(quantizer_id) - model.nncf.add_compression_module(storage_key, quantizer_module, compression_model_type) - fn = ExternalQuantizerCallHook(model.nncf.get_tracing_context(), storage_key) - - insertion_commands.append( - PTInsertionCommand(target_point, fn, TransformationPriority.QUANTIZATION_PRIORITY) - ) - - return PTModelTransformer._apply_insertion_transformations(model, insertion_commands) + return PTModelTransformer._apply_insertion_transformations(model, insertion_commands, device) @staticmethod def _apply_extraction_with_fused_bias_transformations( diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index 5494e3a5620..a27d338a77a 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -17,7 +17,6 @@ from contextlib import contextmanager from copy import deepcopy from dataclasses import dataclass -from enum import Enum from enum import IntEnum from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypeVar @@ -67,6 +66,7 @@ from nncf.torch.graph.operator_metatypes import OPERATORS_WITH_WEIGHTS_METATYPES from nncf.torch.graph.operator_metatypes import PTSplitMetatype from nncf.torch.graph.transformations.commands import DEFAULT_HOOKS_GROUP_NAME +from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType from nncf.torch.graph.transformations.commands import PTTargetPoint from 
nncf.torch.knowledge_distillation.knowledge_distillation_handler import KnowledgeDistillationLossHandler from nncf.torch.layer_utils import _NNCFModuleMixin @@ -142,11 +142,6 @@ def __hash__(self): return hash(str(self)) -class ExtraCompressionModuleType(Enum): - EXTERNAL_QUANTIZER = 0 - EXTERNAL_OP = 1 - - @dataclass class PTGraphPair: """ @@ -576,7 +571,7 @@ def is_scope_in_nncf_module_scope(self, scope: Scope) -> bool: return False def register_compression_module_type(self, compression_module_type: ExtraCompressionModuleType): - attr_name = self._compression_module_type_to_attr_name(compression_module_type) + attr_name = compression_module_type_to_attr_name(compression_module_type) if compression_module_type in self._extra_module_types: raise nncf.ValidationError(f"Module type {compression_module_type} is already registered") @@ -586,7 +581,7 @@ def register_compression_module_type(self, compression_module_type: ExtraCompres def add_compression_module( self, module_key: str, module: nn.Module, compression_module_type: ExtraCompressionModuleType ): - attr_name = self._compression_module_type_to_attr_name(compression_module_type) + attr_name = compression_module_type_to_attr_name(compression_module_type) if compression_module_type not in self._extra_module_types: raise nncf.InternalError(f"Module type {compression_module_type} was not registered") storage = self.__getattr__(attr_name) @@ -595,7 +590,7 @@ def add_compression_module( storage[module_key] = module def get_compression_modules_by_type(self, compression_module_type: ExtraCompressionModuleType) -> nn.ModuleDict: - attr_name = self._compression_module_type_to_attr_name(compression_module_type) + attr_name = compression_module_type_to_attr_name(compression_module_type) if compression_module_type not in self._extra_module_types: raise nncf.InternalError(f"Module type {compression_module_type} was not registered") return self.__getattr__(attr_name) @@ -609,20 +604,8 @@ def is_compression_module_registered(self, compression_module_type: ExtraCompres """ return compression_module_type in self._extra_module_types - @staticmethod - def _compression_module_type_to_attr_name(compression_module_type: ExtraCompressionModuleType): - """ - Required for backward compatibility with checkpoints that store function and activation - quantizers directly under corresponding attributes of NNCFNetwork. - """ - if compression_module_type == ExtraCompressionModuleType.EXTERNAL_QUANTIZER: - return EXTERNAL_QUANTIZERS_STORAGE_NAME - if compression_module_type == ExtraCompressionModuleType.EXTERNAL_OP: - return EXTERNAL_OP_STORAGE_NAME - raise nncf.ValidationError("Unknown extra module type") - def sort_compression_modules(self, compression_module_type: ExtraCompressionModuleType): - attr_name = self._compression_module_type_to_attr_name(compression_module_type) + attr_name = compression_module_type_to_attr_name(compression_module_type) if compression_module_type not in self._extra_module_types: raise nncf.InternalError("Module type {} was not registered".format(compression_module_type)) module_dict = self.__getattr__(attr_name) @@ -1137,3 +1120,15 @@ def hook_fn( def close(self): self.hook.remove() + + +def compression_module_type_to_attr_name(compression_module_type: ExtraCompressionModuleType): + """ + Required for backward compatibility with checkpoints that store function and activation + quantizers directly under corresponding attributes of NNCFNetwork. 
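Reviewer note: the transformer now resolves the target device once in `__init__` and binds it into the handlers via `functools.partial`; multi-device models keep `device=None`, so hooks are left wherever the caller put them. A minimal sketch of that rule, using NNCF's own helpers as imported in this diff:

```python
from nncf.torch.utils import get_model_device
from nncf.torch.utils import is_multidevice


def resolve_insertion_device(model):
    # Single-device models: hooks are moved to the model's device at insertion.
    # Multi-device models: device stays None and hooks are not moved.
    return None if is_multidevice(model) else get_model_device(model)
```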
+ """ + if compression_module_type == ExtraCompressionModuleType.EXTERNAL_QUANTIZER: + return EXTERNAL_QUANTIZERS_STORAGE_NAME + if compression_module_type == ExtraCompressionModuleType.EXTERNAL_OP: + return EXTERNAL_OP_STORAGE_NAME + raise nncf.ValidationError("Unknown extra module type") diff --git a/nncf/torch/quantization/algo.py b/nncf/torch/quantization/algo.py index 278365d8289..24ef0576e9b 100644 --- a/nncf/torch/quantization/algo.py +++ b/nncf/torch/quantization/algo.py @@ -83,6 +83,7 @@ from nncf.torch.graph.operator_metatypes import PTCatMetatype from nncf.torch.graph.operator_metatypes import PTDepthwiseConv2dSubtype from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype +from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType from nncf.torch.graph.transformations.commands import PTInsertionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint from nncf.torch.graph.transformations.commands import TransformationPriority @@ -90,7 +91,6 @@ from nncf.torch.hardware.config import PTHWConfig from nncf.torch.initialization import SimpleDataLoaderRunner from nncf.torch.module_operations import UpdatePaddingValue -from nncf.torch.nncf_network import ExtraCompressionModuleType from nncf.torch.nncf_network import LoadStateListener from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.quantization.adjust_padding import AdjustPaddingArgs @@ -1208,15 +1208,11 @@ def is_weights(ip: PTTargetPoint) -> bool: # share the single module and this would be impossible for multiple weight quantizer sharing if # the corresponding UpdateWeights operations contained real modules (these would simply get copied # by PyTorch internals) - callable_obj = ExternalQuantizerCallHook( - target_model.nncf.get_tracing_context(), external_quantizer_storage_key, self._debug_interface - ) + callable_obj = ExternalQuantizerCallHook(external_quantizer_storage_key, self._debug_interface) else: # Hooks will be identical for each affected op_address in the linked scenario # - will call one and the same quantizer - callable_obj = ExternalQuantizerCallHook( - target_model.nncf.get_tracing_context(), external_quantizer_storage_key, self._debug_interface - ) + callable_obj = ExternalQuantizerCallHook(external_quantizer_storage_key, self._debug_interface) nncf_logger.debug( f"Performing " diff --git a/nncf/torch/quantization/debug_interface.py b/nncf/torch/quantization/debug_interface.py index 60798db1225..b46759ba466 100644 --- a/nncf/torch/quantization/debug_interface.py +++ b/nncf/torch/quantization/debug_interface.py @@ -59,7 +59,7 @@ def __init__(self): self._strict_forward = False def init_actual(self, owner_model: NNCFNetwork): - from nncf.torch.nncf_network import ExtraCompressionModuleType + from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType quantization_types = [class_type.__name__ for class_type in QUANTIZATION_MODULES.registry_dict.values()] quantizers_in_nncf_modules = owner_model.nncf.get_modules_in_nncf_modules_by_type(quantization_types) diff --git a/nncf/torch/quantization/external_quantizer.py b/nncf/torch/quantization/external_quantizer.py index 128441a7eab..7df7d20f994 100644 --- a/nncf/torch/quantization/external_quantizer.py +++ b/nncf/torch/quantization/external_quantizer.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from nncf.torch.dynamic_graph.context import TracingContext from nncf.torch.external_hook import ExternalOpCallHook from nncf.torch.quantization.debug_interface import QuantizationDebugInterface @@ -25,11 +24,10 @@ class ExternalQuantizerCallHook(ExternalOpCallHook): def __init__( self, - context: TracingContext, quantizer_storage_key: str, debug_interface: QuantizationDebugInterface = None, ): - super().__init__(EXTERNAL_QUANTIZERS_STORAGE_NAME, context, quantizer_storage_key) + super().__init__(EXTERNAL_QUANTIZERS_STORAGE_NAME, quantizer_storage_key) self.debug_interface = debug_interface def __call__(self, *args, **kwargs): diff --git a/nncf/torch/quantization/precision_init/base_init.py b/nncf/torch/quantization/precision_init/base_init.py index ef4404c0944..9a1581ebf93 100644 --- a/nncf/torch/quantization/precision_init/base_init.py +++ b/nncf/torch/quantization/precision_init/base_init.py @@ -18,8 +18,8 @@ from nncf.common.quantization.structs import QuantizerId from nncf.common.quantization.structs import WeightQuantizerId from nncf.torch.dynamic_graph.scope import Scope +from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType from nncf.torch.module_operations import UpdateWeight -from nncf.torch.nncf_network import ExtraCompressionModuleType from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.quantization.layers import QUANTIZATION_MODULES from nncf.torch.quantization.layers import BaseQuantizer diff --git a/nncf/torch/quantization/precision_init/hawq_debug.py b/nncf/torch/quantization/precision_init/hawq_debug.py index 697d8367bbc..df8ddfefdbf 100644 --- a/nncf/torch/quantization/precision_init/hawq_debug.py +++ b/nncf/torch/quantization/precision_init/hawq_debug.py @@ -20,7 +20,7 @@ from nncf.common.logging import nncf_logger from nncf.common.utils.decorators import skip_if_dependency_unavailable from nncf.common.utils.dot_file_rw import write_dot_graph -from nncf.torch.nncf_network import ExtraCompressionModuleType +from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.quantization.adjust_padding import add_adjust_padding_nodes from nncf.torch.quantization.layers import QUANTIZATION_MODULES diff --git a/nncf/torch/quantization/strip.py b/nncf/torch/quantization/strip.py index 76dfe2113bf..76cbeac741d 100644 --- a/nncf/torch/quantization/strip.py +++ b/nncf/torch/quantization/strip.py @@ -15,7 +15,7 @@ from torch.quantization.fake_quantize import FakeQuantize import nncf -from nncf.torch.nncf_network import ExtraCompressionModuleType +from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.quantization.layers import AsymmetricQuantizer from nncf.torch.quantization.layers import BaseQuantizer diff --git a/tests/torch/ptq/test_smooth_quant.py b/tests/torch/ptq/test_smooth_quant.py index fa5d0599672..7af4fa98d33 100644 --- a/tests/torch/ptq/test_smooth_quant.py +++ b/tests/torch/ptq/test_smooth_quant.py @@ -21,9 +21,9 @@ from nncf.quantization.algorithms.smooth_quant.torch_backend import SQMultiply from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype from nncf.torch.graph.operator_metatypes import PTModuleLinearMetatype +from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.model_creation import wrap_model -from 
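Reviewer note: moving `compression_module_type_to_attr_name` to module level keeps the checkpoint layout unchanged. A small sketch of the mapping and the state-dict keys it implies (the attribute names come from the storages referenced elsewhere in this diff):

```python
from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
from nncf.torch.nncf_network import compression_module_type_to_attr_name

# Same attribute names as before the refactoring, so old checkpoints still load:
assert compression_module_type_to_attr_name(ExtraCompressionModuleType.EXTERNAL_QUANTIZER) == "external_quantizers"
assert compression_module_type_to_attr_name(ExtraCompressionModuleType.EXTERNAL_OP) == "external_op"

# A shared module registered under key KEY therefore shows up in the state dict as
#   _nncf.external_quantizers.KEY.<param_name>  or  _nncf.external_op.KEY.<param_name>
```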
diff --git a/tests/torch/ptq/test_smooth_quant.py b/tests/torch/ptq/test_smooth_quant.py
index fa5d0599672..7af4fa98d33 100644
--- a/tests/torch/ptq/test_smooth_quant.py
+++ b/tests/torch/ptq/test_smooth_quant.py
@@ -21,9 +21,9 @@
 from nncf.quantization.algorithms.smooth_quant.torch_backend import SQMultiply
 from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype
 from nncf.torch.graph.operator_metatypes import PTModuleLinearMetatype
+from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
 from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
 from nncf.torch.model_creation import wrap_model
-from nncf.torch.nncf_network import ExtraCompressionModuleType
 from tests.post_training.test_templates.helpers import ConvTestModel
 from tests.post_training.test_templates.helpers import LinearMultiShapeModel
 from tests.post_training.test_templates.helpers import ShareWeghtsConvAndShareLinearModel
diff --git a/tests/torch/quantization/test_algo_quantization.py b/tests/torch/quantization/test_algo_quantization.py
index c0d28bcef39..4a70ceebee2 100644
--- a/tests/torch/quantization/test_algo_quantization.py
+++ b/tests/torch/quantization/test_algo_quantization.py
@@ -36,11 +36,11 @@
 from nncf.torch.compression_method_api import PTCompressionLoss
 from nncf.torch.dynamic_graph.scope import Scope
 from nncf.torch.dynamic_graph.scope import ScopeElement
+from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
 from nncf.torch.layers import NNCFConv2d
 from nncf.torch.model_creation import create_compression_algorithm_builder
 from nncf.torch.module_operations import UpdateInputs
 from nncf.torch.module_operations import UpdateWeight
-from nncf.torch.nncf_network import ExtraCompressionModuleType
 from nncf.torch.quantization.algo import QuantizationBuilder
 from nncf.torch.quantization.algo import QuantizationController
 from nncf.torch.quantization.layers import QUANTIZATION_MODULES
diff --git a/tests/torch/quantization/test_strip.py b/tests/torch/quantization/test_strip.py
index 3454c839e08..1c7105f2b91 100644
--- a/tests/torch/quantization/test_strip.py
+++ b/tests/torch/quantization/test_strip.py
@@ -22,7 +22,7 @@
 from nncf.common.quantization.quantizers import get_num_levels
 from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
 from nncf.config import NNCFConfig
-from nncf.torch.nncf_network import ExtraCompressionModuleType
+from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
 from nncf.torch.quantization.layers import AsymmetricQuantizer
 from nncf.torch.quantization.layers import PTQuantizerSpec
 from nncf.torch.quantization.layers import SymmetricQuantizer
diff --git a/tests/torch/test_extractor.py b/tests/torch/test_extractor.py
index b9ba7858d66..e592e6491d8 100644
--- a/tests/torch/test_extractor.py
+++ b/tests/torch/test_extractor.py
@@ -17,7 +17,7 @@
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.torch import wrap_model
 from nncf.torch.extractor import extract_model
-from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand
+from nncf.torch.graph.transformations.command_creation import create_quantizer_insertion_command
 from nncf.torch.graph.transformations.commands import PTTargetPoint
 from nncf.torch.model_transformer import PTModelTransformer
 from nncf.torch.model_transformer import PTTransformationLayout
@@ -97,7 +97,7 @@ def test_extract_model(model_cls, input_node_name, output_node_name):
         ),
     ),
 )
-def tes_extract_model_for_node_with_fq(model_cls, input_node_name, output_node_name):
+def test_extract_model_for_node_with_fq(model_cls, input_node_name, output_node_name):
     example_input = torch.ones(model_cls.INPUT_SIZE)
     model = wrap_model(model_cls().eval(), example_input=example_input, trace_parameters=True)
 
@@ -114,7 +114,7 @@ def tes_extract_model_for_node_with_fq(model_cls, input_node_name, output_node_n
     )
     fq = SymmetricQuantizer(qspec)
 
-    command = PTQuantizerInsertionCommand(
+    command = create_quantizer_insertion_command(
         PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, input_node_name, input_port_id=1), fq
     )
     layout = PTTransformationLayout()
@@ -125,9 +125,10 @@ def tes_extract_model_for_node_with_fq(model_cls, input_node_name, output_node_n
     with torch.no_grad():
         ret1 = q_model(example_input)
         ret2 = extracted_module(example_input)
-        assert torch.any(torch.isclose(ret1, ret2))
+        assert torch.all(torch.isclose(ret1, ret2))
+
+    extracted_fn = extracted_module
+    if isinstance(extracted_fn, nn.Sequential):
+        extracted_fn = extracted_module[0]
 
-    if isinstance(extracted_module, nn.Sequential):
-        assert extracted_module[0].w_fq is not None
-    else:
-        assert extracted_module.w_fq is not None
+    assert extracted_fn.fn_name is not None
diff --git a/tests/torch/test_model_graph_manager.py b/tests/torch/test_model_graph_manager.py
index 89c21c4b883..f1d9d743591 100644
--- a/tests/torch/test_model_graph_manager.py
+++ b/tests/torch/test_model_graph_manager.py
@@ -21,7 +21,7 @@
 from nncf.common.graph.graph import NNCFNode
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.torch import wrap_model
-from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand
+from nncf.torch.graph.transformations.command_creation import create_quantizer_insertion_command
 from nncf.torch.graph.transformations.commands import PTTargetPoint
 from nncf.torch.model_graph_manager import get_const_data
 from nncf.torch.model_graph_manager import get_const_data_on_port
@@ -268,7 +268,7 @@ def test_get_fake_quantizer(target_type, port_id):
     )
     fq = SymmetricQuantizer(qspec)
 
-    command = PTQuantizerInsertionCommand(PTTargetPoint(target_type, node_name, input_port_id=port_id), fq)
+    command = create_quantizer_insertion_command(PTTargetPoint(target_type, node_name, input_port_id=port_id), fq)
     layout = PTTransformationLayout()
     layout.register(command)
     q_model = transformer.transform(layout)
@@ -303,7 +303,9 @@ def test_is_quantized_weights():
     )
     fq = SymmetricQuantizer(qspec)
 
-    command = PTQuantizerInsertionCommand(PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, node_name, input_port_id=1), fq)
+    command = create_quantizer_insertion_command(
+        PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, node_name, input_port_id=1), fq
+    )
     layout = PTTransformationLayout()
     layout.register(command)
     q_model = transformer.transform(layout)
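Reviewer note: the migrated tests above all follow the same pattern; for reference, the minimal end-to-end usage looks like this (the `insert_fq` helper and its arguments are placeholders; quantizer construction is elided, as in the tests):

```python
from nncf.common.graph.transformations.commands import TargetType
from nncf.torch.graph.transformations.command_creation import create_quantizer_insertion_command
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.model_transformer import PTModelTransformer
from nncf.torch.model_transformer import PTTransformationLayout


def insert_fq(model, node_name: str, fq):
    # `model` is a wrapped NNCFNetwork, `fq` an already-built BaseQuantizer.
    command = create_quantizer_insertion_command(
        PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, node_name, input_port_id=1), fq
    )
    layout = PTTransformationLayout()
    layout.register(command)
    return PTModelTransformer(model).transform(layout)
```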
diff --git a/tests/torch/test_model_transformer.py b/tests/torch/test_model_transformer.py
index a8b2172b9c2..c6348644f68 100644
--- a/tests/torch/test_model_transformer.py
+++ b/tests/torch/test_model_transformer.py
@@ -31,6 +31,7 @@
 from nncf.common.insertion_point_graph import InsertionPointGraphNodeType
 from nncf.common.insertion_point_graph import PostHookInsertionPoint
 from nncf.common.insertion_point_graph import PreHookInsertionPoint
+from nncf.common.quantization.structs import NonWeightQuantizerId
 from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
 from nncf.common.utils.backend import BackendType
 from nncf.common.utils.dot_file_rw import get_graph_without_data
@@ -42,17 +43,17 @@
 from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
 from nncf.torch.dynamic_graph.operation_address import OperationAddress
 from nncf.torch.dynamic_graph.patch_pytorch import register_operator
-from nncf.torch.external_hook import EXTERNAL_OP_STORAGE_NAME
 from nncf.torch.external_hook import ExternalOpCallHook
 from nncf.torch.graph.operator_metatypes import PTConv2dMetatype
 from nncf.torch.graph.operator_metatypes import PTInputNoopMetatype
 from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype
 from nncf.torch.graph.operator_metatypes import PTOutputNoopMetatype
 from nncf.torch.graph.operator_metatypes import PTReshapeMetatype
+from nncf.torch.graph.transformations.command_creation import create_quantizer_insertion_command
+from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
 from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
 from nncf.torch.graph.transformations.commands import PTInsertionCommand
 from nncf.torch.graph.transformations.commands import PTModelExtractionWithFusedBiasCommand
-from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand
 from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
 from nncf.torch.graph.transformations.commands import PTTargetPoint
 from nncf.torch.graph.transformations.commands import PTWeightUpdateCommand
@@ -62,12 +63,13 @@
 from nncf.torch.model_transformer import PTModelTransformer
 from nncf.torch.module_operations import BaseOp
 from nncf.torch.module_operations import UpdateWeight
-from nncf.torch.nncf_network import ExtraCompressionModuleType
 from nncf.torch.nncf_network import NNCFNetwork
 from nncf.torch.nncf_network import PTInsertionPoint
 from nncf.torch.nncf_network import PTInsertionType
+from nncf.torch.nncf_network import compression_module_type_to_attr_name
 from nncf.torch.quantization.layers import AsymmetricQuantizer
 from nncf.torch.quantization.layers import PTQuantizerSpec
+from nncf.torch.utils import get_model_device
 from tests.common.quantization.mock_graphs import get_ip_graph_for_test
 from tests.common.quantization.mock_graphs import get_mock_model_graph_with_broken_output_edge_pattern
 from tests.common.quantization.mock_graphs import get_mock_model_graph_with_mergeable_pattern
@@ -157,7 +159,10 @@ def test_single_insertions(self, trace_parameters, target_point: PTTargetPoint):
             not trace_parameters,
         )
         if insertion_point.insertion_type in [PTInsertionType.OPERATOR_PRE_HOOK, PTInsertionType.OPERATOR_POST_HOOK]:
-            hook = lambda x: x
+
+            def hook(x):
+                return x
+
         else:
             hook = BaseOp(lambda x: x)
 
@@ -183,6 +188,75 @@ def test_single_insertions(self, trace_parameters, target_point: PTTargetPoint):
 
         assert len(model.nncf._groups_vs_hooks_handlers[test_hook_group]) == 1
 
+    class BaseOpWithParam(BaseOp):
+        def __init__(self, op):
+            super().__init__(op)
+            self.param1 = torch.nn.Parameter(torch.zeros((1,)))
+            self.param2 = torch.nn.Parameter(torch.zeros((1,)))
+            self.to_device = None
+
+        def to(self, device):
+            super().to(device)
+            self.to_device = device
+
+    @pytest.mark.parametrize("target_point", available_points)
+    @pytest.mark.parametrize("multidevice", (False, True))
+    @pytest.mark.parametrize("hook", (lambda x: x, BaseOpWithParam(lambda x: x).cpu()))
+    def test_pt_insertion_command(self, target_point: PTTargetPoint, multidevice: bool, hook):
+        model = wrap_model(InsertionPointTestModel(), torch.ones([1, 1, 10, 10]))
+
+        if multidevice:
+            if not torch.cuda.is_available():
+                pytest.skip("Cuda is not available, could not run multidevice test case")
+            model.conv2.to("cuda")
+
+        test_hook_group = "test_hook_group"
+        insertion_command = PTInsertionCommand(target_point, hook, hooks_group_name=test_hook_group)
+        layout = PTTransformationLayout()
+        layout.register(insertion_command)
+        transformer = PTModelTransformer(model)
+
+        if target_point.target_type in [
+            TargetType.PRE_LAYER_OPERATION,
+            TargetType.POST_LAYER_OPERATION,
+        ] and not isinstance(hook, nn.Module):
+            with pytest.raises(TypeError):
+                transformer.transform(layout)
+            return
+        transformer.transform(layout)
+
+        insertion_point = PTInsertionPoint(
+            target_point.target_type,
+            model.nncf.get_node_to_op_address_mapping()[target_point.target_node_name],
+            target_point.input_port_id,
+        )
+
+        if target_point.target_type == TargetType.OPERATOR_PRE_HOOK:
+            ctx = model.nncf.get_tracing_context()
+            pre_hook_id = PreHookId(insertion_point.op_address, input_port_id=insertion_point.input_port_id)
+            assert ctx._pre_hooks[pre_hook_id]["0"] is hook
+        elif target_point.target_type == TargetType.OPERATOR_POST_HOOK:
+            ctx = model.nncf.get_tracing_context()
+            assert ctx._post_hooks[insertion_point.op_address]["0"] is hook
+        elif target_point.target_type == TargetType.OPERATION_WITH_WEIGHTS:
+            module = model.nncf.get_module_by_scope(insertion_point.module_scope)
+            w_hook = module.pre_ops["0"]
+            assert isinstance(w_hook, UpdateWeight)
+            assert w_hook.op is hook
+        elif target_point.target_type == TargetType.PRE_LAYER_OPERATION:
+            module = model.nncf.get_module_by_scope(insertion_point.module_scope)
+            assert module.pre_ops["0"] is hook
+        elif target_point.target_type == TargetType.POST_LAYER_OPERATION:
+            module = model.nncf.get_module_by_scope(insertion_point.module_scope)
+            assert module.post_ops["0"] is hook
+        else:
+            raise Exception(f"Unexpected insertion type {insertion_point.insertion_type}")
+
+        if isinstance(hook, nn.Module) and not multidevice:
+            assert hook.to_device == get_model_device(model)
+
+        assert len(model.nncf._groups_vs_hooks_handlers[test_hook_group]) == 1
+
     @staticmethod
     def check_order(iterable1: List, iterable2: List, ordering: List):
         for idx, order in enumerate(ordering):
@@ -554,96 +628,110 @@ class Hook(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self._param = torch.nn.Parameter(torch.zeros((1,)))
+        self.to_device = None
 
     def forward(self, x):
         return x + self._param
 
+    def to(self, device):
+        super().to(device)
+        self.to_device = device
+
 
 @pytest.mark.parametrize(
-    "target_type, node_name, input_port_id, ref_name",
+    "target_type, node_name, input_port_id, ref_name, compression_module_registered",
     (
-        (TargetType.OPERATOR_POST_HOOK, "/nncf_model_input_0", None, "/nncf_model_input_0|OUTPUT"),
+        (
+            TargetType.OPERATOR_POST_HOOK,
+            "/nncf_model_input_0",
+            None,
+            "/nncf_model_input_0|OUTPUT",
+            True,
+        ),
         (
             TargetType.OPERATOR_PRE_HOOK,
             "InsertionPointTestModel/linear_0",
             0,
             "InsertionPointTestModel/linear_0|INPUT0",
+            True,
         ),
-        (TargetType.OPERATION_WITH_WEIGHTS, "InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0", None, None),
+        (TargetType.OPERATION_WITH_WEIGHTS, "InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0", None, None, False),
    ),
 )
-def test_quantizer_insertion_transformations(target_type, node_name, input_port_id, ref_name):
+def test_quantizer_insertion_transformations(
+    target_type, node_name, input_port_id, ref_name, compression_module_registered
+):
     hook = Hook()
 
-    def _insert_quantizer_to_model():
-        model = NNCFNetwork(InsertionPointTestModel(), FillerInputInfo([FillerInputElement([1, 1, 10, 10])]))
-        model_transformer = PTModelTransformer(model)
-
-        target_point = PTTargetPoint(target_type, node_name, input_port_id=input_port_id)
-        command = PTQuantizerInsertionCommand(target_point, hook)
-
-        transformation_layout = PTTransformationLayout()
-        transformation_layout.register(command)
-        return model_transformer.transform(transformation_layout)
-
-    transformed_model = _insert_quantizer_to_model()
-
-    compression_module_type = ExtraCompressionModuleType.EXTERNAL_QUANTIZER
-    assert transformed_model.nncf.is_compression_module_registered(compression_module_type)
-    assert hook in transformed_model.modules()
+    target_point = PTTargetPoint(target_type, node_name, input_port_id=input_port_id)
+    command = create_quantizer_insertion_command(target_point, hook)
 
-    if target_type == TargetType.OPERATION_WITH_WEIGHTS:
-        op = transformed_model.conv1.pre_ops._modules["0"]
-        assert isinstance(op, UpdateWeight)
-        assert isinstance(op.op, Hook)
+    assert command.fn is hook
+    if target_point.type is TargetType.OPERATION_WITH_WEIGHTS:
+        assert isinstance(command, PTInsertionCommand)
     else:
-        external_quantizers = transformed_model.nncf.get_compression_modules_by_type(compression_module_type)
-        assert hasattr(external_quantizers, ref_name)
-        op = getattr(external_quantizers, ref_name)
-        assert isinstance(op, Hook)
+        quantizer_id = NonWeightQuantizerId(target_point.target_node_name, target_point.input_port_id)
+        assert isinstance(command, PTSharedFnInsertionCommand)
+        assert command.target_points == [target_point]
+        assert command.fn is hook
+        storage_key = str(quantizer_id)
+        assert command.op_name == storage_key
+        assert command.compression_module_type is ExtraCompressionModuleType.EXTERNAL_QUANTIZER
 
-    # Check torch can correctly save and load model state dict with an external quantizer
-    state_dict = transformed_model.state_dict()
-    if target_type == TargetType.OPERATION_WITH_WEIGHTS:
-        state_dict_hook_key = "conv1.pre_ops.0.op._param"
-    else:
-        state_dict_hook_key = f"_nncf.external_quantizers.{ref_name}._param"
-    assert state_dict_hook_key in state_dict
-    del transformed_model
-    transformed_model = _insert_quantizer_to_model()
-    transformed_model.load_state_dict(state_dict)
+
+SHARED_FN_TARGET_POINTS = [
+    PTTargetPoint(
+        TargetType.OPERATOR_POST_HOOK,
+        "/nncf_model_input_0",
+    ),
+    PTTargetPoint(
+        TargetType.OPERATOR_PRE_HOOK,
+        "InsertionPointTestModel/linear_0",
+        input_port_id=0,
+    ),
+    PTTargetPoint(
+        TargetType.OPERATION_WITH_WEIGHTS,
+        "InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0",
+    ),
+]
 
 
+@pytest.mark.parametrize("compression_module_type", ExtraCompressionModuleType)
 @pytest.mark.parametrize(
     "priority", [TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION, TransformationPriority.DEFAULT_PRIORITY]
 )
 @pytest.mark.parametrize("compression_module_registered", [False, True])
-def test_shared_fn_insertion_point(priority, compression_module_registered, mocker):
-    tps = [
-        PTTargetPoint(
-            TargetType.OPERATOR_POST_HOOK,
-            "/nncf_model_input_0",
-        ),
-        PTTargetPoint(
-            TargetType.OPERATOR_PRE_HOOK,
-            "InsertionPointTestModel/linear_0",
-            input_port_id=0,
-        ),
-        PTTargetPoint(
-            TargetType.OPERATION_WITH_WEIGHTS,
-            "InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0",
-        ),
-    ]
+@pytest.mark.parametrize("multidevice_model", (False, True))
+def test_shared_fn_insertion_point(
+    priority, compression_module_registered, compression_module_type, multidevice_model, mocker
+):
+    if not torch.cuda.is_available() and multidevice_model:
+        pytest.skip("Could not test multidevice case without cuda")
+
+    tps = SHARED_FN_TARGET_POINTS
     OP_UNIQUE_NAME = "UNIQUE_NAME"
     HOOK_GROUP_NAME = "shared_commands_hooks_group"
+    STORAGE_NAME = compression_module_type_to_attr_name(compression_module_type)
     hook_instance = Hook()
 
     def _insert_external_op_mocked():
         model = NNCFNetwork(InsertionPointTestModel(), FillerInputInfo([FillerInputElement([1, 1, 10, 10])]))
+        model = model.cpu()
+        if multidevice_model:
+            model.conv1.to(torch.device("cpu"))
+            model.conv2.to(torch.device("cuda"))
+
         if compression_module_registered:
-            model.nncf.register_compression_module_type(ExtraCompressionModuleType.EXTERNAL_OP)
+            model.nncf.register_compression_module_type(compression_module_type)
         unique_name = f"{OP_UNIQUE_NAME}[{';'.join([tp.target_node_name for tp in tps])}]"
-        command = PTSharedFnInsertionCommand(tps, hook_instance, unique_name, priority, HOOK_GROUP_NAME)
+        command = PTSharedFnInsertionCommand(
+            target_points=tps,
+            fn=hook_instance,
+            op_unique_name=unique_name,
+            compression_module_type=compression_module_type,
+            priority=priority,
+            hooks_group_name=HOOK_GROUP_NAME,
+        )
 
         transformation_layout = PTTransformationLayout()
         transformation_layout.register(command)
@@ -658,38 +746,126 @@ def _insert_external_op_mocked():
 
     transformed_model = _insert_external_op_mocked()
 
-    assert transformed_model.nncf.is_compression_module_registered(ExtraCompressionModuleType.EXTERNAL_OP)
+    assert transformed_model.nncf.is_compression_module_registered(compression_module_type)
 
     REF_STORAGE_KEY = (
         "UNIQUE_NAME[/nncf_model_input_0;InsertionPointTestModel/linear_0;"
         "InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0]"
     )
 
-    storage = getattr(transformed_model.nncf, EXTERNAL_OP_STORAGE_NAME)
+    storage = getattr(transformed_model.nncf, STORAGE_NAME)
     assert storage[REF_STORAGE_KEY] is hook_instance
     assert hook_instance in transformed_model.modules()
 
     mock = PTModelTransformer._apply_insertion_transformations
     mock.assert_called_once()
-    _, commands = mock.call_args.args
+    _, commands, device = mock.call_args.args
     assert len(commands) == len(tps)
     for command in commands:
         assert command.target_point in tps
         assert command.hooks_group_name == HOOK_GROUP_NAME
+        assert command.priority == priority
         fn = command.fn
         assert isinstance(fn, ExternalOpCallHook)
-        assert fn._storage_name == EXTERNAL_OP_STORAGE_NAME
+        assert fn._storage_name == STORAGE_NAME
         assert fn._storage_key == REF_STORAGE_KEY
 
+    if multidevice_model:
+        assert hook_instance.to_device is None
+        assert device is None
+    else:
+        actual_model_device = get_model_device(transformed_model)
+        assert hook_instance.to_device == actual_model_device
+        assert device == actual_model_device
+
     # Check torch can correctly save and load model state dict with an external quantizer
     state_dict = transformed_model.state_dict()
-    assert f"_nncf.{EXTERNAL_OP_STORAGE_NAME}.{REF_STORAGE_KEY}._param" in state_dict
+    assert f"_nncf.{STORAGE_NAME}.{REF_STORAGE_KEY}._param" in state_dict
     del transformed_model
     transformed_model = _insert_external_op_mocked()
     transformed_model.load_state_dict(state_dict)
 
 
+@pytest.mark.parametrize(
+    "priority", [TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION, TransformationPriority.DEFAULT_PRIORITY]
+)
+@pytest.mark.parametrize("compression_module_registered", [False, True])
+@pytest.mark.parametrize("multidevice_model", (False, True))
+def test_shared_fn_insertion_command_several_module_types(
+    priority, compression_module_registered, multidevice_model, mocker
+):
+    if not torch.cuda.is_available() and multidevice_model:
+        pytest.skip("Could not test multidevice case without cuda")
+
+    tps = SHARED_FN_TARGET_POINTS
+    OP_UNIQUE_NAME = "UNIQUE_NAME"
+    HOOK_GROUP_NAME = "shared_commands_hooks_group"
+    MODULE_TYPES = [t for t in ExtraCompressionModuleType]
+    hook_instance = Hook()
+
+    def _insert_external_op_mocked():
+        model = NNCFNetwork(InsertionPointTestModel(), FillerInputInfo([FillerInputElement([1, 1, 10, 10])]))
+        model = model.cpu()
+        if multidevice_model:
+            model.conv1.to(torch.device("cpu"))
+            model.conv2.to(torch.device("cuda"))
+
+        transformation_layout = PTTransformationLayout()
+        for compression_module_type in MODULE_TYPES:
+            if compression_module_registered:
+                model.nncf.register_compression_module_type(compression_module_type)
+            unique_name = f"{OP_UNIQUE_NAME}[{';'.join([tp.target_node_name for tp in tps])}]"
+            command = PTSharedFnInsertionCommand(
+                target_points=tps,
+                fn=hook_instance,
+                op_unique_name=unique_name,
+                compression_module_type=compression_module_type,
+                priority=priority,
+                hooks_group_name=HOOK_GROUP_NAME,
+            )
+            transformation_layout.register(command)
+
+        mocker.MagicMock()
+        mocker.patch(
+            "nncf.torch.model_transformer.PTModelTransformer._apply_shared_node_insertion_with_compression_type",
+            return_value=mocker.MagicMock(),
+        )
+        model_transformer = PTModelTransformer(model)
+        model_transformer.transform(transformation_layout=transformation_layout)
+        return model
+
+    transformed_model = _insert_external_op_mocked()
+
+    mock = PTModelTransformer._apply_shared_node_insertion_with_compression_type
+    assert len(mock.call_args_list) == len(MODULE_TYPES)
+
+    REF_STORAGE_KEY = (
+        "UNIQUE_NAME[/nncf_model_input_0;InsertionPointTestModel/linear_0;"
+        "InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0]"
+    )
+
+    module_types_set = set(MODULE_TYPES)
+    for (_, commands, device, compression_module_type), _ in mock.call_args_list:
+        module_types_set -= set((compression_module_type,))
+        assert len(commands) == 1
+        command = commands[0]
+        assert isinstance(command, PTSharedFnInsertionCommand)
+        assert command.fn is hook_instance
+        assert command.target_points is tps
+        assert command.compression_module_type == compression_module_type
+        assert command.op_name == REF_STORAGE_KEY
+        assert command.priority == priority
+        assert command.hooks_group_name == HOOK_GROUP_NAME
+
+        if multidevice_model:
+            assert device is None
+        else:
+            assert device == get_model_device(transformed_model)
+
+    assert not module_types_set
+
+
 INSERTION_POINT_TEST_MODEL_TARGET_POINTS = (
     (
         TargetType.OPERATOR_POST_HOOK,
diff --git a/tests/torch/test_nncf_network.py b/tests/torch/test_nncf_network.py
index 938eae5f73c..c4da4be8c82 100644
--- a/tests/torch/test_nncf_network.py
+++ b/tests/torch/test_nncf_network.py
@@ -39,11 +39,11 @@
 from nncf.torch.graph.graph_builder import GraphBuilder
 from nncf.torch.graph.operator_metatypes import PTConv2dMetatype
 from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype
+from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
 from nncf.torch.layer_utils import _NNCFModuleMixin
 from nncf.torch.layers import NNCFConv2d
 from nncf.torch.model_creation import wrap_model
 from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules
-from nncf.torch.nncf_network import ExtraCompressionModuleType
 from nncf.torch.nncf_network import NNCFNetwork
 from nncf.torch.nncf_network import PTInsertionPoint
 from nncf.torch.nncf_network import PTInsertionType
diff --git a/tests/torch/test_tracing_context.py b/tests/torch/test_tracing_context.py
index 37642fad5f6..75b127eb2d6 100644
--- a/tests/torch/test_tracing_context.py
+++ b/tests/torch/test_tracing_context.py
@@ -17,7 +17,7 @@
 from nncf.torch.dynamic_graph.trace_tensor import TracedParameter
 from nncf.torch.dynamic_graph.trace_tensor import TracedTensor
 from nncf.torch.dynamic_graph.wrappers import wrap_parameters
-from nncf.torch.nncf_network import ExtraCompressionModuleType
+from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
 from tests.torch.helpers import BasicConvTestModel
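Reviewer note: dropping the stored `TracingContext` from `ExternalOpCallHook` means a hook is no longer bound to the context that existed at insertion time; it looks up the active context on every call, so copied or deserialized models do not drag a stale context along. A minimal sketch of the resulting call path, assuming a tracing context is active when the hook fires:

```python
from typing import Any

from nncf.torch.dynamic_graph.context import get_current_context


class CallTimeResolvedHook:
    # Same shape as ExternalOpCallHook after this PR: no context captured in __init__.
    def __init__(self, storage_name: str, storage_key: str):
        self._storage_name = storage_name
        self._storage_key = storage_key

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # The active context is resolved at call time, not at insertion time.
        replica = get_current_context().base_module_thread_local_replica
        storage = getattr(replica.nncf, self._storage_name)
        return storage[self._storage_key](*args, **kwargs)
```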