From 15f5428c68f287a99512dcdd901404ef313e32a1 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Thu, 11 Apr 2024 20:01:51 +0200 Subject: [PATCH] Comments --- nncf/torch/graph/graph.py | 6 +- nncf/torch/nncf_network.py | 49 ++++- tests/torch/helpers.py | 176 ++++------------ tests/torch/nncf_network/helpers.py | 194 +++++++++++++++++ .../test_get_applied_modifications.py | 162 ++++++++++++++ .../torch/nncf_network/test_hook_handlers.py | 119 +++++++++++ .../{ => nncf_network}/test_nncf_network.py | 197 +----------------- tests/torch/test_api_behavior.py | 2 +- 8 files changed, 558 insertions(+), 347 deletions(-) create mode 100644 tests/torch/nncf_network/helpers.py create mode 100644 tests/torch/nncf_network/test_get_applied_modifications.py create mode 100644 tests/torch/nncf_network/test_hook_handlers.py rename tests/torch/{ => nncf_network}/test_nncf_network.py (80%) diff --git a/nncf/torch/graph/graph.py b/nncf/torch/graph/graph.py index 63637f759b0..588866367c0 100644 --- a/nncf/torch/graph/graph.py +++ b/nncf/torch/graph/graph.py @@ -61,11 +61,7 @@ def get_op_nodes_in_scope(self, scope: Scope) -> List[NNCFNode]: return matching_graph_op_nodes def get_op_node_in_scope(self, scope: Scope) -> List[NNCFNode]: - for scope_str, nodes_in_module in self._layer_name_vs_shared_nodes.items(): - module_scope = Scope.from_str(scope_str) - if module_scope == scope: - return nodes_in_module - return [] + return self._layer_name_vs_shared_nodes[str(scope)] def get_scope_by_node_name(self, node_name: NNCFNodeName) -> Scope: matches = [] diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index d02f4038d31..01f3201e1e7 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -794,19 +794,19 @@ def get_applied_transformation_layout(self) -> PTTransformationLayout: :return: Transformation layout with all commands applied to the NNCFNetwork. """ - def _create_pt_insert_command(module, target_type, target_node_name, priority, input_port_id): - target_point = PTTargetPoint( - target_type=target_type, target_node_name=target_node_name, input_port_id=input_port_id - ) - return PTInsertionCommand(point=target_point, fn=module, priority=priority) - def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str): + """ + Check given external op call hook reference is correct. + + :param hook: External op call hook to check correctness. + :param info: Info to log in case op call hook references are broken. + """ assert hasattr( self, hook._storage_name ), f"Storage name {hook._storage_name} is not registered. Info: {info}" assert hook._storage_key in getattr( self, hook._storage_name - ), f"Storage key {hook._storage_key} is not registered. Info: {info}" + ), f"Key {hook._storage_key} is not registered in {hook._storage_name}. Info: {info}" context_hooks = defaultdict(lambda: defaultdict(list)) transformation_layout = PTTransformationLayout() @@ -828,14 +828,16 @@ def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str): command_target_type = TargetType.OPERATION_WITH_WEIGHTS module = module.op if not isinstance(module, ExternalOpCallHook): - command = _create_pt_insert_command( + command = create_pt_insertion_command( module, command_target_type, nncf_node.node_name, priority, None ) transformation_layout.register(command) continue - info = f"TargetType: {command_target_type}, nncf node name: {nncf_node.node_name}," - f" priority: {priority}, fn: {module}" + info = ( + f"TargetType: {command_target_type}, nncf node name: {nncf_node.node_name}," + f" priority: {priority}, fn: {module}" + ) _check_external_call_hook_is_valid(module, info) context_hooks[module._storage_name][module._storage_key].append( @@ -859,7 +861,9 @@ def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str): target_node_name = target_node_names[0] if not isinstance(fn, ExternalOpCallHook): - command = _create_pt_insert_command(fn, target_type, target_node_name, priority, input_port_id) + command = create_pt_insertion_command( + fn, target_type, target_node_name, priority, input_port_id + ) transformation_layout.register(command) continue @@ -1262,3 +1266,26 @@ def compression_module_type_to_attr_name(compression_module_type: ExtraCompressi if compression_module_type == ExtraCompressionModuleType.EXTERNAL_OP: return EXTERNAL_OP_STORAGE_NAME raise nncf.ValidationError("Unknown extra module type") + + +def create_pt_insertion_command( + module: torch.nn.Module, + target_type: TargetType, + target_node_name: str, + priority: int, + input_port_id: Optional[int], +) -> PTInsertionCommand: + """ + Creates a PTInsertionCommand. + + :param module: Torch module to insert. + :param target_type: Insertion command target type. + :param target_name: Insertion command target name. + :param priority: Insertion command priority. + :param input_port_id: Insertion command input port id. + :return: A PTInsertionCommand + """ + target_point = PTTargetPoint( + target_type=target_type, target_node_name=target_node_name, input_port_id=input_port_id + ) + return PTInsertionCommand(point=target_point, fn=module, priority=priority) diff --git a/tests/torch/helpers.py b/tests/torch/helpers.py index be02bdec415..4c40f218163 100644 --- a/tests/torch/helpers.py +++ b/tests/torch/helpers.py @@ -10,8 +10,6 @@ # limitations under the License. import contextlib -import functools -import itertools import numbers from abc import ABC from abc import abstractmethod @@ -32,8 +30,6 @@ import nncf from nncf.common.graph.transformations.commands import TargetType -from nncf.common.graph.transformations.commands import TransformationPriority -from nncf.common.graph.transformations.layout import TransformationLayout from nncf.config import NNCFConfig from nncf.config.extractors import extract_algorithm_names from nncf.config.structures import BNAdaptationInitArgs @@ -43,13 +39,11 @@ from nncf.torch.dynamic_graph.io_handling import FillerInputInfo from nncf.torch.dynamic_graph.operation_address import OperationAddress from nncf.torch.dynamic_graph.scope import Scope -from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType from nncf.torch.graph.transformations.commands import PTInsertionCommand from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint from nncf.torch.initialization import PTInitializingDataLoader from nncf.torch.initialization import register_default_init_args -from nncf.torch.layer_utils import COMPRESSION_MODULES from nncf.torch.layers import NNCF_MODULES_MAP from nncf.torch.model_creation import create_compressed_model from nncf.torch.module_operations import UpdateWeight @@ -183,10 +177,6 @@ def nz_bias_num(self): class TwoConvTestModel(nn.Module): INPUT_SHAPE = [1, 1, 4, 4] - NNCF_CONV_NODES_NAMES = [ - "TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0", - "TwoConvTestModel/Sequential[features]/Sequential[1]/NNCFConv2d[0]/conv2d_0", - ] def __init__(self): super().__init__() @@ -214,113 +204,6 @@ def nz_weights_num(self): def nz_bias_num(self): return 2 - @staticmethod - def create_pt_insertion_command( - target_type: TargetType, priority: TransformationPriority, fn=None, group: str = "default_group" - ): - target_point = PTTargetPoint( - target_type=target_type, target_node_name=TwoConvTestModel.NNCF_CONV_NODES_NAMES[0], input_port_id=0 - ) - if fn is None: - fn = DummyOpWithState("DUMMY_STATE") - return PTInsertionCommand(point=target_point, fn=fn, priority=priority, hooks_group_name=group) - - @staticmethod - def create_pt_shared_fn_insertion_command( - target_type: TargetType, - priority: TransformationPriority, - compression_module_type: ExtraCompressionModuleType, - fn=None, - group: str = "default_group", - op_unique_name: str = "UNIQUE_NAME", - ): - target_points = [] - - for node_name in TwoConvTestModel.NNCF_CONV_NODES_NAMES: - target_points.append(PTTargetPoint(target_type=target_type, target_node_name=node_name, input_port_id=0)) - if fn is None: - fn = DummyOpWithState("DUMMY_STATE") - return PTSharedFnInsertionCommand( - target_points=target_points, - fn=fn, - compression_module_type=compression_module_type, - op_unique_name=op_unique_name, - priority=priority, - hooks_group_name=group, - ) - - AVAILABLE_TARGET_TYPES = ( - TargetType.OPERATION_WITH_WEIGHTS, - TargetType.OPERATOR_PRE_HOOK, - TargetType.OPERATOR_POST_HOOK, - TargetType.PRE_LAYER_OPERATION, - TargetType.POST_LAYER_OPERATION, - ) - - @staticmethod - def get_command_builders(): - return ( - TwoConvTestModel.create_pt_insertion_command, - functools.partial( - TwoConvTestModel.create_pt_shared_fn_insertion_command, - compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP, - ), - functools.partial( - TwoConvTestModel.create_pt_shared_fn_insertion_command, - compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER, - ), - ) - - COMMAND_TYPES = [PTInsertionCommand, PTSharedFnInsertionCommand, PTSharedFnInsertionCommand] - PRIORITIES = (TransformationPriority.QUANTIZATION_PRIORITY, TransformationPriority.QUANTIZATION_PRIORITY.value + 1) - - @classmethod - def get_all_available_commands( - cls, dummy_op_state, skip_model_transformer_unsupported=False - ) -> TransformationLayout: - """ - Returns all possible commands to insert: - all target types x all command class x all compression module types x different priorities. - """ - layout = TransformationLayout() - for idx, (target_type, (command_builder, command_type), priority) in enumerate( - itertools.product( - cls.AVAILABLE_TARGET_TYPES, zip(cls.get_command_builders(), cls.COMMAND_TYPES), cls.PRIORITIES - ) - ): - if command_type is PTSharedFnInsertionCommand: - if skip_model_transformer_unsupported and target_type in [ - TargetType.PRE_LAYER_OPERATION, - TargetType.POST_LAYER_OPERATION, - ]: - continue - command = cls._create_command( - command_builder, target_type, priority, dummy_op_state, op_unique_name=f"UNIQUE_NAME_{idx}" - ) - else: - command = cls._create_command(command_builder, target_type, priority, dummy_op_state) - - layout.register(command) - return layout - - @staticmethod - def _create_command(command_builder, target_type, priority, dummy_op_state, op_unique_name=None): - group_name = "CUSTOM_HOOKS_GROUP_NAME" - - if DummyOpWithState.__name__ not in COMPRESSION_MODULES.registry_dict: - registered_dummy_op_cls = COMPRESSION_MODULES.register()(DummyOpWithState) - else: - registered_dummy_op_cls = DummyOpWithState - dummy_op = registered_dummy_op_cls(dummy_op_state) - if op_unique_name is None: - command = command_builder(target_type, priority, fn=dummy_op, group=group_name) - else: - command = command_builder( - target_type, priority, fn=dummy_op, group=group_name, op_unique_name=op_unique_name - ) - - return command - class LeNet(nn.Module): INPUT_SIZE = 1, 32, 32 @@ -372,38 +255,61 @@ def from_state(cls, state: str): return cls(state) -def target_points_are_equal(tp_original: PTTargetPoint, tp_recovered: PTTargetPoint): - if tp_original != tp_recovered: +def target_points_are_equal(tp_left: PTTargetPoint, tp_right: PTTargetPoint) -> bool: + """ + Returns True if given target points are equal and False elsewhere. + + :param tp_left: The first target point. + :param tp_right: The second target point. + :return: True if given target points are equal and False elsewhere. + """ + if tp_left != tp_right: return False - if tp_original.target_type == TargetType.OPERATOR_PRE_HOOK: - return tp_original.input_port_id == tp_recovered.input_port_id + if tp_left.target_type == TargetType.OPERATOR_PRE_HOOK: + return tp_left.input_port_id == tp_right.input_port_id return True -def are_commands_equal( - command, applied_command, check_priority: bool = True, check_hooks_group_name: bool = True, check_fn_ref=True -): - if type(applied_command) is not type(command): +def commands_are_equal( + command_left: Union[PTInsertionCommand, PTSharedFnInsertionCommand], + command_right: Union[PTInsertionCommand, PTSharedFnInsertionCommand], + check_priority: bool = True, + check_hooks_group_name: bool = True, + check_fn_ref=True, +) -> bool: + """ + Returns True if given commands are equal and False elsewhere. + + :param command_left: The first command. + :param command_right: The second command. + :param check_priority: Whether to check insertion priority or not. + :param check_hooks_group_name: Whether to check hooks group name or not. + :param check_fn_ref: Whether to check fn by reference or not. + :returns: True if given commands are equal and False elsewhere. + """ + if type(command_right) is not type(command_left): return False # Check reference to functions are equal. - if check_fn_ref and applied_command.fn is not command.fn: + if check_fn_ref and command_right.fn is not command_left.fn: return False - if check_hooks_group_name and applied_command.hooks_group_name != command.hooks_group_name: + if check_hooks_group_name and command_right.hooks_group_name != command_left.hooks_group_name: return False - if check_priority and applied_command.priority != command.priority: + if check_priority and command_right.priority != command_left.priority: return False - if isinstance(applied_command, PTInsertionCommand): - if not target_points_are_equal(command.target_point, applied_command.target_point): + if isinstance(command_right, PTInsertionCommand): + if not target_points_are_equal(command_left.target_point, command_right.target_point): return False - elif isinstance(applied_command, PTSharedFnInsertionCommand): - if not all(target_points_are_equal(a, b) for a, b in zip(command.target_points, applied_command.target_points)): + elif isinstance(command_right, PTSharedFnInsertionCommand): + if not all( + target_points_are_equal(a, b) for a, b in zip(command_left.target_points, command_right.target_points) + ): return False if ( - applied_command.target_points != command.target_points - or applied_command.op_name != command.op_name - or applied_command.compression_module_type != command.compression_module_type + command_right.target_points != command_left.target_points + or command_right.op_name != command_left.op_name + or command_right.compression_module_type != command_left.compression_module_type ): return False else: diff --git a/tests/torch/nncf_network/helpers.py b/tests/torch/nncf_network/helpers.py new file mode 100644 index 00000000000..a16dba2b263 --- /dev/null +++ b/tests/torch/nncf_network/helpers.py @@ -0,0 +1,194 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import itertools + +import torch + +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.commands import TransformationPriority +from nncf.common.graph.transformations.layout import TransformationLayout +from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType +from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.layer_utils import COMPRESSION_MODULES +from tests.torch.helpers import DummyOpWithState + + +class SimplestModel(torch.nn.Module): + INPUT_SIZE = [1, 1, 32, 32] + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + return self.conv(x) + + +AVAILABLE_TARGET_TYPES = ( + TargetType.OPERATION_WITH_WEIGHTS, + TargetType.OPERATOR_PRE_HOOK, + TargetType.OPERATOR_POST_HOOK, + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, +) + + +class InsertionCommandBuilder: + """ + Contains methods which allows to build all possible commands + for the TwoConvTestModel + """ + + NNCF_CONV_NODES_NAMES = [ + "TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0", + "TwoConvTestModel/Sequential[features]/Sequential[1]/NNCFConv2d[0]/conv2d_0", + ] + CONV_NODES_NAMES = [ + "TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0", + "TwoConvTestModel/Sequential[features]/Sequential[1]/Conv2d[0]/conv2d_0", + ] + + TRACE_VS_NODE_NAMES = {True: CONV_NODES_NAMES, False: NNCF_CONV_NODES_NAMES} + + @classmethod + def create_pt_insertion_command( + cls, + target_type: TargetType, + priority: TransformationPriority, + trace_parameters: bool, + fn=None, + group: str = "default_group", + ): + target_point = PTTargetPoint( + target_type=target_type, target_node_name=cls.TRACE_VS_NODE_NAMES[trace_parameters][0], input_port_id=0 + ) + if fn is None: + fn = DummyOpWithState("DUMMY_STATE") + return PTInsertionCommand(point=target_point, fn=fn, priority=priority, hooks_group_name=group) + + @classmethod + def create_pt_shared_fn_insertion_command( + cls, + target_type: TargetType, + priority: TransformationPriority, + trace_parameters: bool, + compression_module_type: ExtraCompressionModuleType, + fn=None, + group: str = "default_group", + op_unique_name: str = "UNIQUE_NAME", + ): + target_points = [] + + for node_name in cls.TRACE_VS_NODE_NAMES[trace_parameters]: + target_points.append(PTTargetPoint(target_type=target_type, target_node_name=node_name, input_port_id=0)) + if fn is None: + fn = DummyOpWithState("DUMMY_STATE") + return PTSharedFnInsertionCommand( + target_points=target_points, + fn=fn, + compression_module_type=compression_module_type, + op_unique_name=op_unique_name, + priority=priority, + hooks_group_name=group, + ) + + @staticmethod + def get_command_builders(): + return ( + InsertionCommandBuilder.create_pt_insertion_command, + functools.partial( + InsertionCommandBuilder.create_pt_shared_fn_insertion_command, + compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP, + ), + functools.partial( + InsertionCommandBuilder.create_pt_shared_fn_insertion_command, + compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER, + ), + ) + + @classmethod + def get_command_builders_with_types(cls): + return tuple(zip(cls.get_command_builders(), cls.COMMAND_TYPES)) + + COMMAND_TYPES = [PTInsertionCommand, PTSharedFnInsertionCommand, PTSharedFnInsertionCommand] + PRIORITIES = (TransformationPriority.QUANTIZATION_PRIORITY, TransformationPriority.QUANTIZATION_PRIORITY.value + 1) + + @classmethod + def get_all_available_commands( + cls, dummy_op_state, trace_parameters, skip_model_transformer_unsupported=False + ) -> TransformationLayout: + """ + Returns all possible commands to insert: + all target types x all command class x all compression module types x different priorities. + """ + layout = TransformationLayout() + for idx, (target_type, (command_builder, command_type), priority) in enumerate( + itertools.product( + AVAILABLE_TARGET_TYPES, zip(cls.get_command_builders(), cls.COMMAND_TYPES), cls.PRIORITIES + ) + ): + if command_type is PTSharedFnInsertionCommand: + if skip_model_transformer_unsupported and target_type in [ + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ]: + continue + command = cls._create_command( + command_builder, + target_type, + priority, + dummy_op_state, + op_unique_name=f"UNIQUE_NAME_{idx}", + trace_parameters=trace_parameters, + ) + else: + command = cls._create_command( + command_builder, target_type, priority, dummy_op_state, trace_parameters=trace_parameters + ) + + layout.register(command) + return layout + + @staticmethod + def _create_command( + command_builder, + target_type, + priority, + dummy_op_state, + trace_parameters, + op_unique_name=None, + ): + group_name = "CUSTOM_HOOKS_GROUP_NAME" + + if DummyOpWithState.__name__ not in COMPRESSION_MODULES.registry_dict: + registered_dummy_op_cls = COMPRESSION_MODULES.register()(DummyOpWithState) + else: + registered_dummy_op_cls = DummyOpWithState + dummy_op = registered_dummy_op_cls(dummy_op_state) + if op_unique_name is None: + command = command_builder( + target_type, priority, fn=dummy_op, group=group_name, trace_parameters=trace_parameters + ) + else: + command = command_builder( + target_type, + priority, + fn=dummy_op, + group=group_name, + op_unique_name=op_unique_name, + trace_parameters=trace_parameters, + ) + + return command diff --git a/tests/torch/nncf_network/test_get_applied_modifications.py b/tests/torch/nncf_network/test_get_applied_modifications.py new file mode 100644 index 00000000000..d19840821c1 --- /dev/null +++ b/tests/torch/nncf_network/test_get_applied_modifications.py @@ -0,0 +1,162 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.commands import TransformationPriority +from nncf.torch import wrap_model +from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType +from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand +from nncf.torch.graph.transformations.layout import PTTransformationLayout +from nncf.torch.model_transformer import PTModelTransformer +from tests.torch.helpers import TwoConvTestModel +from tests.torch.helpers import commands_are_equal +from tests.torch.nncf_network.helpers import AVAILABLE_TARGET_TYPES +from tests.torch.nncf_network.helpers import InsertionCommandBuilder + +TARGET_TYPE_VS_TARGET_TYPE_DICT_FOR_NOT_REPLACED_MODULES = { + TargetType.PRE_LAYER_OPERATION: TargetType.OPERATOR_PRE_HOOK, + TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK, + TargetType.OPERATION_WITH_WEIGHTS: TargetType.OPERATOR_PRE_HOOK, + TargetType.OPERATOR_PRE_HOOK: TargetType.OPERATOR_PRE_HOOK, + TargetType.OPERATOR_POST_HOOK: TargetType.OPERATOR_POST_HOOK, +} + + +@pytest.fixture(name="trace_parameters", params=(True, False)) +def trace_parameters_fixture(request) -> bool: + return request.param + + +def _translate_target_types(trace_parameters, command): + """ + Translates target types in case trace_parameters is True + """ + if not trace_parameters: + return + if isinstance(command, PTInsertionCommand): + target_points = [command.target_point] + else: + target_points = command.target_points + + for target_point in target_points: + new_target_type = TARGET_TYPE_VS_TARGET_TYPE_DICT_FOR_NOT_REPLACED_MODULES[target_point.type] + target_point._target_type = new_target_type + target_point.target_type = new_target_type + + +@pytest.mark.parametrize("target_type", AVAILABLE_TARGET_TYPES) +@pytest.mark.parametrize("command_builder", InsertionCommandBuilder.get_command_builders()) +def test_get_applied_modification_commands(command_builder, target_type, trace_parameters): + command = command_builder(target_type, TransformationPriority.DEFAULT_PRIORITY, trace_parameters=trace_parameters) + if isinstance(command, PTSharedFnInsertionCommand) and target_type in [ + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ]: + pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") + + model = TwoConvTestModel() + nncf_model = wrap_model(model, torch.zeros([1, 1, 4, 4]), trace_parameters=trace_parameters) + model_transformer = PTModelTransformer(nncf_model) + + layout = PTTransformationLayout() + layout.register(command) + model_transformer.transform(layout) + + applied_commands = nncf_model.nncf.get_applied_transformation_layout() + + assert len(applied_commands.transformations) == 1 + applied_command = applied_commands.transformations[0] + _translate_target_types(trace_parameters, command) + assert commands_are_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) + + +@pytest.mark.parametrize("target_type", AVAILABLE_TARGET_TYPES) +@pytest.mark.parametrize("command_builder,command_type", InsertionCommandBuilder.get_command_builders_with_types()) +def test_priority_of_get_applied_modification_commands(command_builder, target_type, command_type, trace_parameters): + layout = PTTransformationLayout() + commands = dict() + for priority in (0, 3, 2, 4, 1): + if command_type is PTSharedFnInsertionCommand: + command = command_builder( + target_type, priority, op_unique_name=f"UNIQUE_NAME_{priority}", trace_parameters=trace_parameters + ) + else: + command = command_builder(target_type, priority, trace_parameters=trace_parameters) + layout.register(command) + commands[priority] = command + else: + if isinstance(command, PTSharedFnInsertionCommand) and target_type in [ + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ]: + pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") + + model = TwoConvTestModel() + nncf_model = wrap_model(model, torch.zeros([1, 1, 4, 4]), trace_parameters=trace_parameters) + model_tranformer = PTModelTransformer(nncf_model) + + model_tranformer.transform(layout) + + applied_commands = nncf_model.nncf.get_applied_transformation_layout() + assert len(applied_commands.transformations) == len(commands) + for applied_command in applied_commands.transformations: + command = commands[applied_command.priority] + _translate_target_types(trace_parameters, command) + assert commands_are_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) + + +def test_all_possible_combinations_of_commands_for_get_applied_commands(trace_parameters): + dummy_state = "DummyState" + commands = InsertionCommandBuilder.get_all_available_commands( + dummy_state, skip_model_transformer_unsupported=True, trace_parameters=trace_parameters + ) + + model = TwoConvTestModel() + nncf_model = wrap_model(model, torch.zeros([1, 1, 4, 4]), trace_parameters=trace_parameters) + model_tranformer = PTModelTransformer(nncf_model) + + model_tranformer.transform(commands) + + applied_commands = nncf_model.nncf.get_applied_transformation_layout() + assert len(applied_commands.transformations) == len(commands.transformations) + for command in commands.transformations: + _translate_target_types(trace_parameters, command) + eq_commands = ( + commands_are_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) + for applied_command in applied_commands.transformations + ) + if sum(map(int, eq_commands)) != 1: + raise RuntimeError(f"Command {command} has no pair in recovered commands") + + +@pytest.mark.parametrize("target_type", (TargetType.OPERATION_WITH_WEIGHTS, TargetType.OPERATOR_PRE_HOOK)) +def test_get_applied_modification_commands_broken_call_hook(target_type, trace_parameters): + model = TwoConvTestModel() + nncf_model = wrap_model(model, torch.zeros([1, 1, 4, 4]), trace_parameters=trace_parameters) + model_tranformer = PTModelTransformer(nncf_model) + + layout = PTTransformationLayout() + command = InsertionCommandBuilder.create_pt_shared_fn_insertion_command( + target_type=target_type, + priority=0, + compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP, + trace_parameters=trace_parameters, + ) + layout.register(command) + model_tranformer.transform(layout) + + nncf_model.nncf.external_op.clear() + with pytest.raises(AssertionError): + nncf_model.nncf.get_applied_transformation_layout() diff --git a/tests/torch/nncf_network/test_hook_handlers.py b/tests/torch/nncf_network/test_hook_handlers.py new file mode 100644 index 00000000000..06c22d24a84 --- /dev/null +++ b/tests/torch/nncf_network/test_hook_handlers.py @@ -0,0 +1,119 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Tuple + +import pytest +import torch + +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.hook_handle import HookHandle +from nncf.torch.dynamic_graph.io_handling import ExampleInputInfo +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.nncf_network import PTInsertionPoint +from tests.torch.helpers import HookChecker +from tests.torch.nncf_network.helpers import SimplestModel + + +@pytest.mark.parametrize( + "target_type, target_node_name, input_port_id", + [ + (TargetType.OPERATOR_PRE_HOOK, "/nncf_model_output_0", 0), + (TargetType.OPERATOR_POST_HOOK, "/nncf_model_input_0", 0), + (TargetType.PRE_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0), + (TargetType.POST_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0), + ], +) +class TestHookHandles: + class TestHook(torch.nn.Module): + def __init__(self): + super().__init__() + self._p = torch.nn.Parameter(torch.zeros((1,))) + + def forward(self, x): + return x + self._p + + @staticmethod + def _prepare_hook_handles_test( + target_type: TargetType, target_node_name: str, input_port_id: int + ) -> Tuple[NNCFNetwork, PTInsertionPoint, Callable[[List[HookHandle]], None]]: + model = SimplestModel() + example_input = torch.ones(SimplestModel.INPUT_SIZE) + input_info = ExampleInputInfo.from_example_input(example_input) + nncf_model = NNCFNetwork(model, input_info) + + node_name_vs_address = nncf_model.nncf.get_node_to_op_address_mapping() + ip = PTInsertionPoint(target_type, node_name_vs_address[target_node_name], input_port_id=input_port_id) + + checker = HookChecker(nncf_model, "conv") + + def _check(ref_hooks_): + checker.clear() + checker.add_ref(ref_hooks_, target_type, target_node_name, input_port_id) + checker.check_with_reference() + + return nncf_model, ip, _check + + def test_temporary_insert_at_point_by_hook_group_name( + self, target_type: TargetType, target_node_name: str, input_port_id: int + ): + nncf_model, ip, _check = self._prepare_hook_handles_test(target_type, target_node_name, input_port_id) + permanent_hook = self.TestHook() + TEMPORARY_HOOK_GROUP_NAME = "tmp" + # Make temporary hook a ref to the permanent hook + # to check tmp hooks are not removed by their id() + temporary_hook = permanent_hook + nncf_model.nncf.insert_at_point(ip, permanent_hook) + ref_hooks = [permanent_hook] + _check(ref_hooks) + + for _ in range(2): + temporary_hook = self.TestHook() + nncf_model.nncf.insert_at_point(ip, temporary_hook, TEMPORARY_HOOK_GROUP_NAME) + ref_hooks.append(temporary_hook) + _check(ref_hooks) + + nncf_model.nncf.insert_at_point(ip, permanent_hook) + ref_hooks.append(permanent_hook) + _check(ref_hooks) + + nncf_model.nncf.remove_hooks_group(TEMPORARY_HOOK_GROUP_NAME) + del ref_hooks[-2] + _check(ref_hooks) + assert not nncf_model.nncf._groups_vs_hooks_handlers[TEMPORARY_HOOK_GROUP_NAME] + + def test_insert_at_point_hook_handles(self, target_type: TargetType, target_node_name: str, input_port_id: int): + nncf_model, ip, _check = self._prepare_hook_handles_test(target_type, target_node_name, input_port_id) + permanent_hook = self.TestHook() + # Make temporary hook a ref to the permanent hook + # to check tmp hooks are not removed by their id() + temporary_hook = permanent_hook + tmp_hh = [] + nncf_model.nncf.insert_at_point(ip, permanent_hook) + + ref_hooks = [permanent_hook] + _check(ref_hooks) + + for _ in range(2): + temporary_hook = self.TestHook() + tmp_hh.append(nncf_model.nncf.insert_at_point(ip, temporary_hook)) + ref_hooks.append(temporary_hook) + _check(ref_hooks) + + nncf_model.nncf.insert_at_point(ip, permanent_hook) + ref_hooks.append(permanent_hook) + _check(ref_hooks) + + for hh in tmp_hh: + hh.remove() + + del ref_hooks[-2] + _check(ref_hooks) diff --git a/tests/torch/test_nncf_network.py b/tests/torch/nncf_network/test_nncf_network.py similarity index 80% rename from tests/torch/test_nncf_network.py rename to tests/torch/nncf_network/test_nncf_network.py index 331d8c8e951..5fff987ea44 100644 --- a/tests/torch/test_nncf_network.py +++ b/tests/torch/nncf_network/test_nncf_network.py @@ -14,7 +14,7 @@ from abc import ABCMeta from abc import abstractmethod from copy import deepcopy -from typing import Callable, List, Tuple, Type +from typing import Callable, Type import pytest import torch @@ -27,8 +27,6 @@ from nncf.common.graph import NNCFNode from nncf.common.graph.operator_metatypes import UnknownMetatype from nncf.common.graph.transformations.commands import TargetType -from nncf.common.graph.transformations.commands import TransformationPriority -from nncf.common.hook_handle import HookHandle from nncf.torch import register_module from nncf.torch.dynamic_graph.io_handling import ExampleInputInfo from nncf.torch.dynamic_graph.io_handling import FillerInputElement @@ -41,12 +39,9 @@ from nncf.torch.graph.operator_metatypes import PTConv2dMetatype from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType -from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand -from nncf.torch.graph.transformations.layout import PTTransformationLayout from nncf.torch.layer_utils import _NNCFModuleMixin from nncf.torch.layers import NNCFConv2d from nncf.torch.model_creation import wrap_model -from nncf.torch.model_transformer import PTModelTransformer from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.nncf_network import PTInsertionPoint @@ -54,12 +49,11 @@ from nncf.torch.quantization.external_quantizer import EXTERNAL_QUANTIZERS_STORAGE_NAME from tests.torch.composite.test_sparsity_quantization import get_basic_sparsity_plus_quantization_config from tests.torch.helpers import BasicConvTestModel -from tests.torch.helpers import HookChecker from tests.torch.helpers import TwoConvTestModel -from tests.torch.helpers import are_commands_equal from tests.torch.helpers import check_correct_nncf_modules_replacement from tests.torch.helpers import create_compressed_model_and_algo_for_test from tests.torch.helpers import register_bn_adaptation_init_args +from tests.torch.nncf_network.helpers import SimplestModel from tests.torch.test_models.synthetic import ManyNonEvalModules @@ -618,17 +612,6 @@ def test_can_work_with_sequential_models(): _ = model.nncf.get_clean_shallow_copy() -class SimplestModel(torch.nn.Module): - INPUT_SIZE = [1, 1, 32, 32] - - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(1, 1, 1) - - def forward(self, x): - return self.conv(x) - - @pytest.fixture(name="simple_net") def simple_net_(): model = NNCFNetwork(SimplestModel(), FillerInputInfo([FillerInputElement(SimplestModel.INPUT_SIZE)])) @@ -933,179 +916,3 @@ def test_insert_hook_after_parameter(): assert hook.forward_calls_counter == 1 assert torch.sum(result.nonzero()) > 0 assert torch.sum(result_with_hook.nonzero()) == 0 - - -@pytest.mark.parametrize( - "target_type, target_node_name, input_port_id", - [ - (TargetType.OPERATOR_PRE_HOOK, "/nncf_model_output_0", 0), - (TargetType.OPERATOR_POST_HOOK, "/nncf_model_input_0", 0), - (TargetType.PRE_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0), - (TargetType.POST_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0), - ], -) -class TestHookHandles: - class TestHook(torch.nn.Module): - def __init__(self): - super().__init__() - self._p = torch.nn.Parameter(torch.zeros((1,))) - - def forward(self, x): - return x + self._p - - @staticmethod - def _prepare_hook_handles_test( - target_type: TargetType, target_node_name: str, input_port_id: int - ) -> Tuple[NNCFNetwork, PTInsertionPoint, Callable[[List[HookHandle]], None]]: - model = SimplestModel() - example_input = torch.ones(SimplestModel.INPUT_SIZE) - input_info = ExampleInputInfo.from_example_input(example_input) - nncf_model = NNCFNetwork(model, input_info) - - node_name_vs_address = nncf_model.nncf.get_node_to_op_address_mapping() - ip = PTInsertionPoint(target_type, node_name_vs_address[target_node_name], input_port_id=input_port_id) - - checker = HookChecker(nncf_model, "conv") - - def _check(ref_hooks_): - checker.clear() - checker.add_ref(ref_hooks_, target_type, target_node_name, input_port_id) - checker.check_with_reference() - - return nncf_model, ip, _check - - def test_temporary_insert_at_point_by_hook_group_name( - self, target_type: TargetType, target_node_name: str, input_port_id: int - ): - nncf_model, ip, _check = self._prepare_hook_handles_test(target_type, target_node_name, input_port_id) - permanent_hook = self.TestHook() - TEMPORARY_HOOK_GROUP_NAME = "tmp" - # Make temporary hook a ref to the permanent hook - # to check tmp hooks are not removed by their id() - temporary_hook = permanent_hook - nncf_model.nncf.insert_at_point(ip, permanent_hook) - ref_hooks = [permanent_hook] - _check(ref_hooks) - - for _ in range(2): - temporary_hook = self.TestHook() - nncf_model.nncf.insert_at_point(ip, temporary_hook, TEMPORARY_HOOK_GROUP_NAME) - ref_hooks.append(temporary_hook) - _check(ref_hooks) - - nncf_model.nncf.insert_at_point(ip, permanent_hook) - ref_hooks.append(permanent_hook) - _check(ref_hooks) - - nncf_model.nncf.remove_hooks_group(TEMPORARY_HOOK_GROUP_NAME) - del ref_hooks[-2] - _check(ref_hooks) - assert not nncf_model.nncf._groups_vs_hooks_handlers[TEMPORARY_HOOK_GROUP_NAME] - - def test_insert_at_point_hook_handles(self, target_type: TargetType, target_node_name: str, input_port_id: int): - nncf_model, ip, _check = self._prepare_hook_handles_test(target_type, target_node_name, input_port_id) - permanent_hook = self.TestHook() - # Make temporary hook a ref to the permanent hook - # to check tmp hooks are not removed by their id() - temporary_hook = permanent_hook - tmp_hh = [] - nncf_model.nncf.insert_at_point(ip, permanent_hook) - - ref_hooks = [permanent_hook] - _check(ref_hooks) - - for _ in range(2): - temporary_hook = self.TestHook() - tmp_hh.append(nncf_model.nncf.insert_at_point(ip, temporary_hook)) - ref_hooks.append(temporary_hook) - _check(ref_hooks) - - nncf_model.nncf.insert_at_point(ip, permanent_hook) - ref_hooks.append(permanent_hook) - _check(ref_hooks) - - for hh in tmp_hh: - hh.remove() - - del ref_hooks[-2] - _check(ref_hooks) - - -@pytest.mark.parametrize("target_type", TwoConvTestModel.AVAILABLE_TARGET_TYPES) -@pytest.mark.parametrize("command_builder", TwoConvTestModel.get_command_builders()) -def test_get_applied_modification_commands(command_builder, target_type): - command = command_builder(target_type, TransformationPriority.DEFAULT_PRIORITY) - if isinstance(command, PTSharedFnInsertionCommand) and target_type in [ - TargetType.PRE_LAYER_OPERATION, - TargetType.POST_LAYER_OPERATION, - ]: - pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") - - model = TwoConvTestModel() - nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) - model_tranformer = PTModelTransformer(nncf_model) - - layout = PTTransformationLayout() - layout.register(command) - model_tranformer.transform(layout) - - applied_commands = nncf_model.nncf.get_applied_transformation_layout() - - assert len(applied_commands.transformations) == 1 - applied_command = applied_commands.transformations[0] - are_commands_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) - - -@pytest.mark.parametrize("target_type", TwoConvTestModel.AVAILABLE_TARGET_TYPES) -@pytest.mark.parametrize( - "command_builder,command_type", tuple(zip(TwoConvTestModel.get_command_builders(), TwoConvTestModel.COMMAND_TYPES)) -) -def test_priority_of_get_applied_modification_commands(command_builder, target_type, command_type): - layout = PTTransformationLayout() - commands = dict() - for priority in (0, 3, 2, 4, 1): - if command_type is PTSharedFnInsertionCommand: - command = command_builder(target_type, priority, op_unique_name=f"UNIQUE_NAME_{priority}") - else: - command = command_builder(target_type, priority) - layout.register(command) - commands[priority] = command - else: - if isinstance(command, PTSharedFnInsertionCommand) and target_type in [ - TargetType.PRE_LAYER_OPERATION, - TargetType.POST_LAYER_OPERATION, - ]: - pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") - - model = TwoConvTestModel() - nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) - model_tranformer = PTModelTransformer(nncf_model) - - model_tranformer.transform(layout) - - applied_commands = nncf_model.nncf.get_applied_transformation_layout() - assert len(applied_commands.transformations) == len(commands) - for applied_command in applied_commands.transformations: - command = commands[applied_command.priority] - are_commands_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) - - -def test_all_possible_combinations_of_commands_for_get_applied_commands(): - dummy_state = "DummyState" - commands = TwoConvTestModel.get_all_available_commands(dummy_state, skip_model_transformer_unsupported=True) - - model = TwoConvTestModel() - nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) - model_tranformer = PTModelTransformer(nncf_model) - - model_tranformer.transform(commands) - - applied_commands = nncf_model.nncf.get_applied_transformation_layout() - assert len(applied_commands.transformations) == len(commands.transformations) - for command in commands.transformations: - eq_commands = ( - are_commands_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) - for applied_command in applied_commands.transformations - ) - if sum(map(int, eq_commands)) != 1: - raise RuntimeError(f"Command {command} has no pair in recovered commands") diff --git a/tests/torch/test_api_behavior.py b/tests/torch/test_api_behavior.py index 9e18479e94e..565eaa7e86b 100644 --- a/tests/torch/test_api_behavior.py +++ b/tests/torch/test_api_behavior.py @@ -24,7 +24,7 @@ from tests.torch.helpers import OnesDatasetMock from tests.torch.helpers import TwoConvTestModel from tests.torch.helpers import create_compressed_model_and_algo_for_test -from tests.torch.test_nncf_network import SimplestModel +from tests.torch.nncf_network.helpers import SimplestModel INPUT_SAMPLE_SIZE = [1, 1, 4, 4] CONFIG_WITH_ALL_INIT_TYPES = {