From eea2e97bbc5ef59476f2437e7bf0a8ba380ab3fe Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Fri, 22 Mar 2024 14:50:25 +0100 Subject: [PATCH 1/9] [Torch] NNCFNetwork.get_applied_transformation_layout --- nncf/torch/graph/graph.py | 7 ++ nncf/torch/nncf_network.py | 127 ++++++++++++++++++++++ tests/torch/test_nncf_network.py | 179 +++++++++++++++++++++++++++++++ 3 files changed, 313 insertions(+) diff --git a/nncf/torch/graph/graph.py b/nncf/torch/graph/graph.py index 1a651f9a363..63637f759b0 100644 --- a/nncf/torch/graph/graph.py +++ b/nncf/torch/graph/graph.py @@ -60,6 +60,13 @@ def get_op_nodes_in_scope(self, scope: Scope) -> List[NNCFNode]: matching_graph_op_nodes.extend(nodes_in_module) return matching_graph_op_nodes + def get_op_node_in_scope(self, scope: Scope) -> List[NNCFNode]: + for scope_str, nodes_in_module in self._layer_name_vs_shared_nodes.items(): + module_scope = Scope.from_str(scope_str) + if module_scope == scope: + return nodes_in_module + return [] + def get_scope_by_node_name(self, node_name: NNCFNodeName) -> Scope: matches = [] for node_id, scope_str in self._node_ids_vs_layer_names.items(): diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index a27d338a77a..c06fd82b0a7 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -41,6 +41,7 @@ from nncf.common.utils.debug import is_debug from nncf.torch.debug import CombinedDebugInterface from nncf.torch.debug import debuggable_forward +from nncf.torch.dynamic_graph.context import PreHookId from nncf.torch.dynamic_graph.context import TracingContext from nncf.torch.dynamic_graph.graph import DynamicGraph from nncf.torch.dynamic_graph.graph import ShapeIgnoringTensorMetaComparator @@ -60,6 +61,7 @@ from nncf.torch.dynamic_graph.wrappers import wrap_module_call from nncf.torch.dynamic_graph.wrappers import wrap_parameters from nncf.torch.external_hook import EXTERNAL_OP_STORAGE_NAME +from nncf.torch.external_hook import ExternalOpCallHook from nncf.torch.graph.graph import PTNNCFGraph from nncf.torch.graph.graph_builder import GraphBuilder from nncf.torch.graph.graph_builder import GraphConverter @@ -67,9 +69,13 @@ from nncf.torch.graph.operator_metatypes import PTSplitMetatype from nncf.torch.graph.transformations.commands import DEFAULT_HOOKS_GROUP_NAME from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType +from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.graph.transformations.layout import PTTransformationLayout from nncf.torch.knowledge_distillation.knowledge_distillation_handler import KnowledgeDistillationLossHandler from nncf.torch.layer_utils import _NNCFModuleMixin +from nncf.torch.module_operations import UpdateWeight from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules from nncf.torch.quantization.external_quantizer import EXTERNAL_QUANTIZERS_STORAGE_NAME from nncf.torch.utils import compute_FLOPs_hook @@ -778,6 +784,116 @@ def _collect_eval_op_scopes(self, model: nn.Module, dummy_forward_fn: Callable) result.append(scope_in_model) return result + def get_applied_transformation_layout(self) -> PTTransformationLayout: + """ + Collects all hooks applied to the NNCFNetwork, converts them to insertion commands + and returns in PTTransformationLayout format. Default hooks group name is used in + recovered commands, so hooks group names specified diring the model modification + become outdated. + + :return: Transformation layout with all commands applied to the NNCFNetwork. + """ + + def _create_pt_insert_command(module, target_type, target_node_name, priority, input_port_id): + target_point = PTTargetPoint( + target_type=target_type, target_node_name=target_node_name, input_port_id=input_port_id + ) + return PTInsertionCommand(point=target_point, fn=module, priority=priority) + + def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str): + assert hasattr( + self, hook._storage_name + ), f"Storage name {hook._storage_name} is not registered. Info: {info}" + assert hook._storage_key in getattr( + self, hook._storage_name + ), f"Storage key {hook._storage_key} is not registered. Info: {info}" + + context_hooks = defaultdict(lambda: defaultdict(list)) + transformation_layout = PTTransformationLayout() + nncf_graph = self.get_graph() + nncf_node_names_map = self.get_op_address_to_op_name_map() + + # Collect pre/post layer and op with weights insertion commands + for nncf_module, module_scope in self.get_nncf_modules().items(): + for ops, target_type in ( + (nncf_module.pre_ops, TargetType.PRE_LAYER_OPERATION), + (nncf_module.post_ops, TargetType.POST_LAYER_OPERATION), + ): + for priority, module in enumerate(ops.values()): + nodes_in_scope = nncf_graph.get_op_node_in_scope(module_scope) + assert len(nodes_in_scope) == 1 + nncf_node = nodes_in_scope[0] + if isinstance(module, UpdateWeight): + target_type = TargetType.OPERATION_WITH_WEIGHTS + module = module.op + if not isinstance(module, ExternalOpCallHook): + command = _create_pt_insert_command(module, target_type, nncf_node.node_name, priority, None) + transformation_layout.register(command) + continue + + info = f"TargetType: {target_type}, nncf node name: {nncf_node.node_name}," + f" priority: {priority}, fn: {module}" + _check_external_call_hook_is_valid(module, info) + + context_hooks[module._storage_name][module._storage_key].append( + (target_type, nncf_node.node_name, priority, module, None) + ) + + # Collect all pre/post hooks commands + for ops, target_type in ( + (self._compressed_context._pre_hooks, TargetType.OPERATOR_PRE_HOOK), + (self._compressed_context._post_hooks, TargetType.OPERATOR_POST_HOOK), + ): + for op_address, hooks in ops.items(): + if isinstance(op_address, PreHookId): + input_port_id = op_address.input_port_id + op_address = op_address.op_address + else: + input_port_id = None + for priority, fn in enumerate(hooks.values()): + target_node_names = nncf_node_names_map[op_address] + assert len(target_node_names) == 1 + target_node_name = target_node_names[0] + + if not isinstance(fn, ExternalOpCallHook): + command = _create_pt_insert_command(fn, target_type, target_node_name, priority, input_port_id) + transformation_layout.register(command) + continue + + info = f"TargetType: {target_type}, op_address: {op_address}, priority: {priority}, fn: {fn}" + _check_external_call_hook_is_valid(fn, info) + + context_hooks[fn._storage_name][fn._storage_key].append( + (target_type, target_node_name, priority, fn, input_port_id) + ) + + # Create shared fn insertion commands according to external hooks collected from + # pre/post layer, pre/post hooks and op with weights target points. + for module_type_name, storage in context_hooks.items(): + for storage_key, call_hook_list_info in storage.items(): + compression_module = getattr(self, module_type_name)[storage_key] + target_points = [] + for target_type, target_node_name, priority, fn, input_port_id in call_hook_list_info: + target_points.append(PTTargetPoint(target_type, target_node_name, input_port_id=input_port_id)) + + if module_type_name == EXTERNAL_QUANTIZERS_STORAGE_NAME: + module_type = ExtraCompressionModuleType.EXTERNAL_QUANTIZER + elif module_type_name == EXTERNAL_OP_STORAGE_NAME: + module_type = ExtraCompressionModuleType.EXTERNAL_OP + else: + raise RuntimeError(f"Module type {module_type_name} is not supported") + + command = PTSharedFnInsertionCommand( + target_points=target_points, + fn=compression_module, + op_unique_name=storage_key, + compression_module_type=module_type, + priority=priority, + ) + transformation_layout.register(command) + + return transformation_layout + def get_node_to_op_address_mapping(self) -> Dict[NNCFNodeName, OperationAddress]: """ Returns map of NNCFGraph node names vs DynamicGraph operation addresses. @@ -796,6 +912,17 @@ def get_node_to_op_address_mapping(self) -> Dict[NNCFNodeName, OperationAddress] retval[nncf_node.node_name] = op_address return retval + def get_op_address_to_op_name_map(self) -> Dict[OperationAddress, NNCFNodeName]: + """ + Returns map of DynamicGraph operation addresses vs NNCFGraph node names. + + :return: DynamicGraph operation addresses vs NNCFGraph node names. + """ + retval = defaultdict(list) + for nncf_node_name, op_address in self.get_node_to_op_address_mapping().items(): + retval[op_address].append(nncf_node_name) + return retval + def set_compression_controller(self, ctrl: CompressionAlgorithmController): self.compression_controller = ctrl diff --git a/tests/torch/test_nncf_network.py b/tests/torch/test_nncf_network.py index c4da4be8c82..f282c95887a 100644 --- a/tests/torch/test_nncf_network.py +++ b/tests/torch/test_nncf_network.py @@ -27,6 +27,7 @@ from nncf.common.graph import NNCFNode from nncf.common.graph.operator_metatypes import UnknownMetatype from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.commands import TransformationPriority from nncf.common.hook_handle import HookHandle from nncf.torch import register_module from nncf.torch.dynamic_graph.io_handling import ExampleInputInfo @@ -40,9 +41,14 @@ from nncf.torch.graph.operator_metatypes import PTConv2dMetatype from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType +from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.graph.transformations.layout import PTTransformationLayout from nncf.torch.layer_utils import _NNCFModuleMixin from nncf.torch.layers import NNCFConv2d from nncf.torch.model_creation import wrap_model +from nncf.torch.model_transformer import PTModelTransformer from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.nncf_network import PTInsertionPoint @@ -1024,3 +1030,176 @@ def test_insert_at_point_hook_handles(self, target_type: TargetType, target_node del ref_hooks[-2] _check(ref_hooks) + + +class DummyOpWithState(torch.nn.Module): + def __init__(self, state: str): + super().__init__() + self._state = state + + def __call__(self, *args): + if len(args) == 1: + return args[0] + # To work correctly with + # TargetType.PRE_LAYER_OPERATION + # TargetType.POST_LAYER_OPERATION + return None + + def get_state(self): + return self._state.copy() + + @classmethod + def from_state(cls, state: str): + return cls(state) + + +TWO_CONV_MODEL_NODES_NAMES = [ + "TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0", + "TwoConvTestModel/Sequential[features]/Sequential[1]/NNCFConv2d[0]/conv2d_0", +] + + +def _create_pt_insertion_command( + target_type: TargetType, priority: TransformationPriority, group: str = "default_group" +): + target_point = PTTargetPoint( + target_type=target_type, target_node_name=TWO_CONV_MODEL_NODES_NAMES[0], input_port_id=0 + ) + fn = DummyOpWithState("DUMMY_STATE") + return PTInsertionCommand(point=target_point, fn=fn, priority=priority, hooks_group_name=group) + + +def _create_pt_shared_fn_insertion_command( + target_type: TargetType, + priority: TransformationPriority, + compression_module_type: ExtraCompressionModuleType, + group: str = "default_group", + op_unique_name: str = "UNIQUE_NAME", +): + target_points = [] + + for node_name in TWO_CONV_MODEL_NODES_NAMES: + target_points.append(PTTargetPoint(target_type=target_type, target_node_name=node_name, input_port_id=0)) + fn = DummyOpWithState("DUMMY_STATE") + return PTSharedFnInsertionCommand( + target_points=target_points, + fn=fn, + compression_module_type=compression_module_type, + op_unique_name=op_unique_name, + priority=priority, + hooks_group_name=group, + ) + + +@pytest.mark.parametrize( + "target_type", + ( + TargetType.OPERATION_WITH_WEIGHTS, + TargetType.OPERATOR_PRE_HOOK, + TargetType.OPERATOR_POST_HOOK, + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ), +) +@pytest.mark.parametrize( + "command_builder,command_type", + ( + (_create_pt_insertion_command, PTInsertionCommand), + ( + functools.partial( + _create_pt_shared_fn_insertion_command, compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP + ), + PTSharedFnInsertionCommand, + ), + ( + functools.partial( + _create_pt_shared_fn_insertion_command, + compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER, + ), + PTSharedFnInsertionCommand, + ), + ), +) +class TestGetAppliedModificationCommands: + def test_get_applied_modification_commands(self, command_builder, target_type, command_type): + command = command_builder(target_type, TransformationPriority.DEFAULT_PRIORITY) + if isinstance(command, PTSharedFnInsertionCommand) and target_type in [ + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ]: + pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") + + model = TwoConvTestModel() + nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) + model_tranformer = PTModelTransformer(nncf_model) + + layout = PTTransformationLayout() + layout.register(command) + model_tranformer.transform(layout) + + applied_commands = nncf_model.nncf.get_applied_transformation_layout() + + assert len(applied_commands.transformations) == 1 + applied_command = applied_commands.transformations[0] + self._check_commands_are_equal_except_priority_and_hooks_group(command, applied_command) + + def test_priority_of_get_applied_modification_commands(self, command_builder, target_type, command_type): + layout = PTTransformationLayout() + commands = dict() + for priority in (0, 3, 2, 4, 1): + if command_type is PTSharedFnInsertionCommand: + command = command_builder(target_type, priority, op_unique_name=f"UNIQUE_NAME_{priority}") + else: + command = command_builder(target_type, priority) + layout.register(command) + commands[priority] = command + else: + if isinstance(command, PTSharedFnInsertionCommand) and target_type in [ + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ]: + pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") + + model = TwoConvTestModel() + nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) + model_tranformer = PTModelTransformer(nncf_model) + + model_tranformer.transform(layout) + + applied_commands = nncf_model.nncf.get_applied_transformation_layout() + assert len(applied_commands.transformations) == len(commands) + for applied_command in applied_commands.transformations: + command = commands[applied_command.priority] + self._check_commands_are_equal_except_priority_and_hooks_group(command, applied_command) + + @staticmethod + def _target_points_are_equal(tp_original: PTTargetPoint, tp_recovered: PTTargetPoint): + if tp_original != tp_recovered: + return False + if tp_original.target_type == TargetType.OPERATOR_PRE_HOOK: + return tp_original.input_port_id == tp_recovered.input_port_id + return True + + @staticmethod + def _check_commands_are_equal_except_priority_and_hooks_group(command, applied_command): + assert type(applied_command) is type(command) + # Check reference to functions are equal. + # Important for the priority check + assert applied_command.fn is command.fn + ### TODO: map hooks group name + # assert applied_command.hooks_group_name == command.hooks_group_name + + if isinstance(applied_command, PTInsertionCommand): + assert TestGetAppliedModificationCommands._target_points_are_equal( + command.target_point, applied_command.target_point + ) + elif isinstance(applied_command, PTSharedFnInsertionCommand): + all( + TestGetAppliedModificationCommands._target_points_are_equal(a, b) + for a, b in zip(command.target_points, applied_command.target_points) + ) + assert applied_command.target_points == command.target_points + assert applied_command.op_name == command.op_name + assert applied_command.compression_module_type == command.compression_module_type + else: + raise RuntimeError() From e9654174cd4d3c062551e293ff3cd40380cef332 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Wed, 27 Mar 2024 15:32:16 +0100 Subject: [PATCH 2/9] Fix target_type detection bug --- nncf/torch/nncf_network.py | 11 +- tests/torch/helpers.py | 183 ++++++++++++++++++++++++++ tests/torch/test_nncf_network.py | 212 +++++++++---------------------- 3 files changed, 249 insertions(+), 157 deletions(-) diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index c06fd82b0a7..d02f4038d31 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -823,20 +823,23 @@ def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str): nodes_in_scope = nncf_graph.get_op_node_in_scope(module_scope) assert len(nodes_in_scope) == 1 nncf_node = nodes_in_scope[0] + command_target_type = target_type if isinstance(module, UpdateWeight): - target_type = TargetType.OPERATION_WITH_WEIGHTS + command_target_type = TargetType.OPERATION_WITH_WEIGHTS module = module.op if not isinstance(module, ExternalOpCallHook): - command = _create_pt_insert_command(module, target_type, nncf_node.node_name, priority, None) + command = _create_pt_insert_command( + module, command_target_type, nncf_node.node_name, priority, None + ) transformation_layout.register(command) continue - info = f"TargetType: {target_type}, nncf node name: {nncf_node.node_name}," + info = f"TargetType: {command_target_type}, nncf node name: {nncf_node.node_name}," f" priority: {priority}, fn: {module}" _check_external_call_hook_is_valid(module, info) context_hooks[module._storage_name][module._storage_key].append( - (target_type, nncf_node.node_name, priority, module, None) + (command_target_type, nncf_node.node_name, priority, module, None) ) # Collect all pre/post hooks commands diff --git a/tests/torch/helpers.py b/tests/torch/helpers.py index 3dfe3a3df7e..be02bdec415 100644 --- a/tests/torch/helpers.py +++ b/tests/torch/helpers.py @@ -8,7 +8,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import contextlib +import functools +import itertools import numbers from abc import ABC from abc import abstractmethod @@ -29,6 +32,8 @@ import nncf from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.commands import TransformationPriority +from nncf.common.graph.transformations.layout import TransformationLayout from nncf.config import NNCFConfig from nncf.config.extractors import extract_algorithm_names from nncf.config.structures import BNAdaptationInitArgs @@ -38,8 +43,13 @@ from nncf.torch.dynamic_graph.io_handling import FillerInputInfo from nncf.torch.dynamic_graph.operation_address import OperationAddress from nncf.torch.dynamic_graph.scope import Scope +from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType +from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint from nncf.torch.initialization import PTInitializingDataLoader from nncf.torch.initialization import register_default_init_args +from nncf.torch.layer_utils import COMPRESSION_MODULES from nncf.torch.layers import NNCF_MODULES_MAP from nncf.torch.model_creation import create_compressed_model from nncf.torch.module_operations import UpdateWeight @@ -172,6 +182,12 @@ def nz_bias_num(self): class TwoConvTestModel(nn.Module): + INPUT_SHAPE = [1, 1, 4, 4] + NNCF_CONV_NODES_NAMES = [ + "TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0", + "TwoConvTestModel/Sequential[features]/Sequential[1]/NNCFConv2d[0]/conv2d_0", + ] + def __init__(self): super().__init__() self.features = [] @@ -198,6 +214,113 @@ def nz_weights_num(self): def nz_bias_num(self): return 2 + @staticmethod + def create_pt_insertion_command( + target_type: TargetType, priority: TransformationPriority, fn=None, group: str = "default_group" + ): + target_point = PTTargetPoint( + target_type=target_type, target_node_name=TwoConvTestModel.NNCF_CONV_NODES_NAMES[0], input_port_id=0 + ) + if fn is None: + fn = DummyOpWithState("DUMMY_STATE") + return PTInsertionCommand(point=target_point, fn=fn, priority=priority, hooks_group_name=group) + + @staticmethod + def create_pt_shared_fn_insertion_command( + target_type: TargetType, + priority: TransformationPriority, + compression_module_type: ExtraCompressionModuleType, + fn=None, + group: str = "default_group", + op_unique_name: str = "UNIQUE_NAME", + ): + target_points = [] + + for node_name in TwoConvTestModel.NNCF_CONV_NODES_NAMES: + target_points.append(PTTargetPoint(target_type=target_type, target_node_name=node_name, input_port_id=0)) + if fn is None: + fn = DummyOpWithState("DUMMY_STATE") + return PTSharedFnInsertionCommand( + target_points=target_points, + fn=fn, + compression_module_type=compression_module_type, + op_unique_name=op_unique_name, + priority=priority, + hooks_group_name=group, + ) + + AVAILABLE_TARGET_TYPES = ( + TargetType.OPERATION_WITH_WEIGHTS, + TargetType.OPERATOR_PRE_HOOK, + TargetType.OPERATOR_POST_HOOK, + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ) + + @staticmethod + def get_command_builders(): + return ( + TwoConvTestModel.create_pt_insertion_command, + functools.partial( + TwoConvTestModel.create_pt_shared_fn_insertion_command, + compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP, + ), + functools.partial( + TwoConvTestModel.create_pt_shared_fn_insertion_command, + compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER, + ), + ) + + COMMAND_TYPES = [PTInsertionCommand, PTSharedFnInsertionCommand, PTSharedFnInsertionCommand] + PRIORITIES = (TransformationPriority.QUANTIZATION_PRIORITY, TransformationPriority.QUANTIZATION_PRIORITY.value + 1) + + @classmethod + def get_all_available_commands( + cls, dummy_op_state, skip_model_transformer_unsupported=False + ) -> TransformationLayout: + """ + Returns all possible commands to insert: + all target types x all command class x all compression module types x different priorities. + """ + layout = TransformationLayout() + for idx, (target_type, (command_builder, command_type), priority) in enumerate( + itertools.product( + cls.AVAILABLE_TARGET_TYPES, zip(cls.get_command_builders(), cls.COMMAND_TYPES), cls.PRIORITIES + ) + ): + if command_type is PTSharedFnInsertionCommand: + if skip_model_transformer_unsupported and target_type in [ + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ]: + continue + command = cls._create_command( + command_builder, target_type, priority, dummy_op_state, op_unique_name=f"UNIQUE_NAME_{idx}" + ) + else: + command = cls._create_command(command_builder, target_type, priority, dummy_op_state) + + layout.register(command) + return layout + + @staticmethod + def _create_command(command_builder, target_type, priority, dummy_op_state, op_unique_name=None): + group_name = "CUSTOM_HOOKS_GROUP_NAME" + + if DummyOpWithState.__name__ not in COMPRESSION_MODULES.registry_dict: + registered_dummy_op_cls = COMPRESSION_MODULES.register()(DummyOpWithState) + else: + registered_dummy_op_cls = DummyOpWithState + dummy_op = registered_dummy_op_cls(dummy_op_state) + if op_unique_name is None: + command = command_builder(target_type, priority, fn=dummy_op, group=group_name) + else: + command = command_builder( + target_type, priority, fn=dummy_op, group=group_name, op_unique_name=op_unique_name + ) + + return command + class LeNet(nn.Module): INPUT_SIZE = 1, 32, 32 @@ -228,6 +351,66 @@ def num_flat_features(self, x): return num_features +class DummyOpWithState(torch.nn.Module): + def __init__(self, state: str): + super().__init__() + self._state = state + + def __call__(self, *args): + if len(args) == 1: + return args[0] + # To work correctly with + # TargetType.PRE_LAYER_OPERATION + # TargetType.POST_LAYER_OPERATION + return None + + def get_state(self): + return self._state + + @classmethod + def from_state(cls, state: str): + return cls(state) + + +def target_points_are_equal(tp_original: PTTargetPoint, tp_recovered: PTTargetPoint): + if tp_original != tp_recovered: + return False + if tp_original.target_type == TargetType.OPERATOR_PRE_HOOK: + return tp_original.input_port_id == tp_recovered.input_port_id + return True + + +def are_commands_equal( + command, applied_command, check_priority: bool = True, check_hooks_group_name: bool = True, check_fn_ref=True +): + if type(applied_command) is not type(command): + return False + + # Check reference to functions are equal. + if check_fn_ref and applied_command.fn is not command.fn: + return False + if check_hooks_group_name and applied_command.hooks_group_name != command.hooks_group_name: + return False + if check_priority and applied_command.priority != command.priority: + return False + + if isinstance(applied_command, PTInsertionCommand): + if not target_points_are_equal(command.target_point, applied_command.target_point): + return False + elif isinstance(applied_command, PTSharedFnInsertionCommand): + if not all(target_points_are_equal(a, b) for a, b in zip(command.target_points, applied_command.target_points)): + return False + if ( + applied_command.target_points != command.target_points + or applied_command.op_name != command.op_name + or applied_command.compression_module_type != command.compression_module_type + ): + return False + else: + raise RuntimeError() + return True + + class SharedConv(nn.Module): INPUT_SIZE = [1, 1, 4, 4] diff --git a/tests/torch/test_nncf_network.py b/tests/torch/test_nncf_network.py index f282c95887a..331d8c8e951 100644 --- a/tests/torch/test_nncf_network.py +++ b/tests/torch/test_nncf_network.py @@ -41,9 +41,7 @@ from nncf.torch.graph.operator_metatypes import PTConv2dMetatype from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType -from nncf.torch.graph.transformations.commands import PTInsertionCommand from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand -from nncf.torch.graph.transformations.commands import PTTargetPoint from nncf.torch.graph.transformations.layout import PTTransformationLayout from nncf.torch.layer_utils import _NNCFModuleMixin from nncf.torch.layers import NNCFConv2d @@ -58,6 +56,7 @@ from tests.torch.helpers import BasicConvTestModel from tests.torch.helpers import HookChecker from tests.torch.helpers import TwoConvTestModel +from tests.torch.helpers import are_commands_equal from tests.torch.helpers import check_correct_nncf_modules_replacement from tests.torch.helpers import create_compressed_model_and_algo_for_test from tests.torch.helpers import register_bn_adaptation_init_args @@ -1032,174 +1031,81 @@ def test_insert_at_point_hook_handles(self, target_type: TargetType, target_node _check(ref_hooks) -class DummyOpWithState(torch.nn.Module): - def __init__(self, state: str): - super().__init__() - self._state = state - - def __call__(self, *args): - if len(args) == 1: - return args[0] - # To work correctly with - # TargetType.PRE_LAYER_OPERATION - # TargetType.POST_LAYER_OPERATION - return None +@pytest.mark.parametrize("target_type", TwoConvTestModel.AVAILABLE_TARGET_TYPES) +@pytest.mark.parametrize("command_builder", TwoConvTestModel.get_command_builders()) +def test_get_applied_modification_commands(command_builder, target_type): + command = command_builder(target_type, TransformationPriority.DEFAULT_PRIORITY) + if isinstance(command, PTSharedFnInsertionCommand) and target_type in [ + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ]: + pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") - def get_state(self): - return self._state.copy() + model = TwoConvTestModel() + nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) + model_tranformer = PTModelTransformer(nncf_model) - @classmethod - def from_state(cls, state: str): - return cls(state) + layout = PTTransformationLayout() + layout.register(command) + model_tranformer.transform(layout) + applied_commands = nncf_model.nncf.get_applied_transformation_layout() -TWO_CONV_MODEL_NODES_NAMES = [ - "TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0", - "TwoConvTestModel/Sequential[features]/Sequential[1]/NNCFConv2d[0]/conv2d_0", -] - - -def _create_pt_insertion_command( - target_type: TargetType, priority: TransformationPriority, group: str = "default_group" -): - target_point = PTTargetPoint( - target_type=target_type, target_node_name=TWO_CONV_MODEL_NODES_NAMES[0], input_port_id=0 - ) - fn = DummyOpWithState("DUMMY_STATE") - return PTInsertionCommand(point=target_point, fn=fn, priority=priority, hooks_group_name=group) - - -def _create_pt_shared_fn_insertion_command( - target_type: TargetType, - priority: TransformationPriority, - compression_module_type: ExtraCompressionModuleType, - group: str = "default_group", - op_unique_name: str = "UNIQUE_NAME", -): - target_points = [] - - for node_name in TWO_CONV_MODEL_NODES_NAMES: - target_points.append(PTTargetPoint(target_type=target_type, target_node_name=node_name, input_port_id=0)) - fn = DummyOpWithState("DUMMY_STATE") - return PTSharedFnInsertionCommand( - target_points=target_points, - fn=fn, - compression_module_type=compression_module_type, - op_unique_name=op_unique_name, - priority=priority, - hooks_group_name=group, - ) + assert len(applied_commands.transformations) == 1 + applied_command = applied_commands.transformations[0] + are_commands_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) +@pytest.mark.parametrize("target_type", TwoConvTestModel.AVAILABLE_TARGET_TYPES) @pytest.mark.parametrize( - "target_type", - ( - TargetType.OPERATION_WITH_WEIGHTS, - TargetType.OPERATOR_PRE_HOOK, - TargetType.OPERATOR_POST_HOOK, - TargetType.PRE_LAYER_OPERATION, - TargetType.POST_LAYER_OPERATION, - ), -) -@pytest.mark.parametrize( - "command_builder,command_type", - ( - (_create_pt_insertion_command, PTInsertionCommand), - ( - functools.partial( - _create_pt_shared_fn_insertion_command, compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP - ), - PTSharedFnInsertionCommand, - ), - ( - functools.partial( - _create_pt_shared_fn_insertion_command, - compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER, - ), - PTSharedFnInsertionCommand, - ), - ), + "command_builder,command_type", tuple(zip(TwoConvTestModel.get_command_builders(), TwoConvTestModel.COMMAND_TYPES)) ) -class TestGetAppliedModificationCommands: - def test_get_applied_modification_commands(self, command_builder, target_type, command_type): - command = command_builder(target_type, TransformationPriority.DEFAULT_PRIORITY) +def test_priority_of_get_applied_modification_commands(command_builder, target_type, command_type): + layout = PTTransformationLayout() + commands = dict() + for priority in (0, 3, 2, 4, 1): + if command_type is PTSharedFnInsertionCommand: + command = command_builder(target_type, priority, op_unique_name=f"UNIQUE_NAME_{priority}") + else: + command = command_builder(target_type, priority) + layout.register(command) + commands[priority] = command + else: if isinstance(command, PTSharedFnInsertionCommand) and target_type in [ TargetType.PRE_LAYER_OPERATION, TargetType.POST_LAYER_OPERATION, ]: pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") - model = TwoConvTestModel() - nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) - model_tranformer = PTModelTransformer(nncf_model) + model = TwoConvTestModel() + nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) + model_tranformer = PTModelTransformer(nncf_model) - layout = PTTransformationLayout() - layout.register(command) - model_tranformer.transform(layout) - - applied_commands = nncf_model.nncf.get_applied_transformation_layout() - - assert len(applied_commands.transformations) == 1 - applied_command = applied_commands.transformations[0] - self._check_commands_are_equal_except_priority_and_hooks_group(command, applied_command) - - def test_priority_of_get_applied_modification_commands(self, command_builder, target_type, command_type): - layout = PTTransformationLayout() - commands = dict() - for priority in (0, 3, 2, 4, 1): - if command_type is PTSharedFnInsertionCommand: - command = command_builder(target_type, priority, op_unique_name=f"UNIQUE_NAME_{priority}") - else: - command = command_builder(target_type, priority) - layout.register(command) - commands[priority] = command - else: - if isinstance(command, PTSharedFnInsertionCommand) and target_type in [ - TargetType.PRE_LAYER_OPERATION, - TargetType.POST_LAYER_OPERATION, - ]: - pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") + model_tranformer.transform(layout) - model = TwoConvTestModel() - nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) - model_tranformer = PTModelTransformer(nncf_model) + applied_commands = nncf_model.nncf.get_applied_transformation_layout() + assert len(applied_commands.transformations) == len(commands) + for applied_command in applied_commands.transformations: + command = commands[applied_command.priority] + are_commands_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) - model_tranformer.transform(layout) - applied_commands = nncf_model.nncf.get_applied_transformation_layout() - assert len(applied_commands.transformations) == len(commands) - for applied_command in applied_commands.transformations: - command = commands[applied_command.priority] - self._check_commands_are_equal_except_priority_and_hooks_group(command, applied_command) +def test_all_possible_combinations_of_commands_for_get_applied_commands(): + dummy_state = "DummyState" + commands = TwoConvTestModel.get_all_available_commands(dummy_state, skip_model_transformer_unsupported=True) - @staticmethod - def _target_points_are_equal(tp_original: PTTargetPoint, tp_recovered: PTTargetPoint): - if tp_original != tp_recovered: - return False - if tp_original.target_type == TargetType.OPERATOR_PRE_HOOK: - return tp_original.input_port_id == tp_recovered.input_port_id - return True + model = TwoConvTestModel() + nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) + model_tranformer = PTModelTransformer(nncf_model) - @staticmethod - def _check_commands_are_equal_except_priority_and_hooks_group(command, applied_command): - assert type(applied_command) is type(command) - # Check reference to functions are equal. - # Important for the priority check - assert applied_command.fn is command.fn - ### TODO: map hooks group name - # assert applied_command.hooks_group_name == command.hooks_group_name - - if isinstance(applied_command, PTInsertionCommand): - assert TestGetAppliedModificationCommands._target_points_are_equal( - command.target_point, applied_command.target_point - ) - elif isinstance(applied_command, PTSharedFnInsertionCommand): - all( - TestGetAppliedModificationCommands._target_points_are_equal(a, b) - for a, b in zip(command.target_points, applied_command.target_points) - ) - assert applied_command.target_points == command.target_points - assert applied_command.op_name == command.op_name - assert applied_command.compression_module_type == command.compression_module_type - else: - raise RuntimeError() + model_tranformer.transform(commands) + + applied_commands = nncf_model.nncf.get_applied_transformation_layout() + assert len(applied_commands.transformations) == len(commands.transformations) + for command in commands.transformations: + eq_commands = ( + are_commands_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) + for applied_command in applied_commands.transformations + ) + if sum(map(int, eq_commands)) != 1: + raise RuntimeError(f"Command {command} has no pair in recovered commands") From d010da6164522dfcc9f49221b4fb961ba7871f11 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Thu, 11 Apr 2024 20:01:51 +0200 Subject: [PATCH 3/9] Comments --- nncf/torch/graph/graph.py | 6 +- nncf/torch/nncf_network.py | 49 ++++- tests/torch/helpers.py | 176 ++++------------ tests/torch/nncf_network/helpers.py | 194 +++++++++++++++++ .../test_get_applied_modifications.py | 162 ++++++++++++++ .../torch/nncf_network/test_hook_handlers.py | 119 +++++++++++ .../{ => nncf_network}/test_nncf_network.py | 197 +----------------- tests/torch/test_api_behavior.py | 2 +- 8 files changed, 558 insertions(+), 347 deletions(-) create mode 100644 tests/torch/nncf_network/helpers.py create mode 100644 tests/torch/nncf_network/test_get_applied_modifications.py create mode 100644 tests/torch/nncf_network/test_hook_handlers.py rename tests/torch/{ => nncf_network}/test_nncf_network.py (80%) diff --git a/nncf/torch/graph/graph.py b/nncf/torch/graph/graph.py index 63637f759b0..588866367c0 100644 --- a/nncf/torch/graph/graph.py +++ b/nncf/torch/graph/graph.py @@ -61,11 +61,7 @@ def get_op_nodes_in_scope(self, scope: Scope) -> List[NNCFNode]: return matching_graph_op_nodes def get_op_node_in_scope(self, scope: Scope) -> List[NNCFNode]: - for scope_str, nodes_in_module in self._layer_name_vs_shared_nodes.items(): - module_scope = Scope.from_str(scope_str) - if module_scope == scope: - return nodes_in_module - return [] + return self._layer_name_vs_shared_nodes[str(scope)] def get_scope_by_node_name(self, node_name: NNCFNodeName) -> Scope: matches = [] diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index d02f4038d31..01f3201e1e7 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -794,19 +794,19 @@ def get_applied_transformation_layout(self) -> PTTransformationLayout: :return: Transformation layout with all commands applied to the NNCFNetwork. """ - def _create_pt_insert_command(module, target_type, target_node_name, priority, input_port_id): - target_point = PTTargetPoint( - target_type=target_type, target_node_name=target_node_name, input_port_id=input_port_id - ) - return PTInsertionCommand(point=target_point, fn=module, priority=priority) - def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str): + """ + Check given external op call hook reference is correct. + + :param hook: External op call hook to check correctness. + :param info: Info to log in case op call hook references are broken. + """ assert hasattr( self, hook._storage_name ), f"Storage name {hook._storage_name} is not registered. Info: {info}" assert hook._storage_key in getattr( self, hook._storage_name - ), f"Storage key {hook._storage_key} is not registered. Info: {info}" + ), f"Key {hook._storage_key} is not registered in {hook._storage_name}. Info: {info}" context_hooks = defaultdict(lambda: defaultdict(list)) transformation_layout = PTTransformationLayout() @@ -828,14 +828,16 @@ def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str): command_target_type = TargetType.OPERATION_WITH_WEIGHTS module = module.op if not isinstance(module, ExternalOpCallHook): - command = _create_pt_insert_command( + command = create_pt_insertion_command( module, command_target_type, nncf_node.node_name, priority, None ) transformation_layout.register(command) continue - info = f"TargetType: {command_target_type}, nncf node name: {nncf_node.node_name}," - f" priority: {priority}, fn: {module}" + info = ( + f"TargetType: {command_target_type}, nncf node name: {nncf_node.node_name}," + f" priority: {priority}, fn: {module}" + ) _check_external_call_hook_is_valid(module, info) context_hooks[module._storage_name][module._storage_key].append( @@ -859,7 +861,9 @@ def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str): target_node_name = target_node_names[0] if not isinstance(fn, ExternalOpCallHook): - command = _create_pt_insert_command(fn, target_type, target_node_name, priority, input_port_id) + command = create_pt_insertion_command( + fn, target_type, target_node_name, priority, input_port_id + ) transformation_layout.register(command) continue @@ -1262,3 +1266,26 @@ def compression_module_type_to_attr_name(compression_module_type: ExtraCompressi if compression_module_type == ExtraCompressionModuleType.EXTERNAL_OP: return EXTERNAL_OP_STORAGE_NAME raise nncf.ValidationError("Unknown extra module type") + + +def create_pt_insertion_command( + module: torch.nn.Module, + target_type: TargetType, + target_node_name: str, + priority: int, + input_port_id: Optional[int], +) -> PTInsertionCommand: + """ + Creates a PTInsertionCommand. + + :param module: Torch module to insert. + :param target_type: Insertion command target type. + :param target_name: Insertion command target name. + :param priority: Insertion command priority. + :param input_port_id: Insertion command input port id. + :return: A PTInsertionCommand + """ + target_point = PTTargetPoint( + target_type=target_type, target_node_name=target_node_name, input_port_id=input_port_id + ) + return PTInsertionCommand(point=target_point, fn=module, priority=priority) diff --git a/tests/torch/helpers.py b/tests/torch/helpers.py index be02bdec415..4c40f218163 100644 --- a/tests/torch/helpers.py +++ b/tests/torch/helpers.py @@ -10,8 +10,6 @@ # limitations under the License. import contextlib -import functools -import itertools import numbers from abc import ABC from abc import abstractmethod @@ -32,8 +30,6 @@ import nncf from nncf.common.graph.transformations.commands import TargetType -from nncf.common.graph.transformations.commands import TransformationPriority -from nncf.common.graph.transformations.layout import TransformationLayout from nncf.config import NNCFConfig from nncf.config.extractors import extract_algorithm_names from nncf.config.structures import BNAdaptationInitArgs @@ -43,13 +39,11 @@ from nncf.torch.dynamic_graph.io_handling import FillerInputInfo from nncf.torch.dynamic_graph.operation_address import OperationAddress from nncf.torch.dynamic_graph.scope import Scope -from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType from nncf.torch.graph.transformations.commands import PTInsertionCommand from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint from nncf.torch.initialization import PTInitializingDataLoader from nncf.torch.initialization import register_default_init_args -from nncf.torch.layer_utils import COMPRESSION_MODULES from nncf.torch.layers import NNCF_MODULES_MAP from nncf.torch.model_creation import create_compressed_model from nncf.torch.module_operations import UpdateWeight @@ -183,10 +177,6 @@ def nz_bias_num(self): class TwoConvTestModel(nn.Module): INPUT_SHAPE = [1, 1, 4, 4] - NNCF_CONV_NODES_NAMES = [ - "TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0", - "TwoConvTestModel/Sequential[features]/Sequential[1]/NNCFConv2d[0]/conv2d_0", - ] def __init__(self): super().__init__() @@ -214,113 +204,6 @@ def nz_weights_num(self): def nz_bias_num(self): return 2 - @staticmethod - def create_pt_insertion_command( - target_type: TargetType, priority: TransformationPriority, fn=None, group: str = "default_group" - ): - target_point = PTTargetPoint( - target_type=target_type, target_node_name=TwoConvTestModel.NNCF_CONV_NODES_NAMES[0], input_port_id=0 - ) - if fn is None: - fn = DummyOpWithState("DUMMY_STATE") - return PTInsertionCommand(point=target_point, fn=fn, priority=priority, hooks_group_name=group) - - @staticmethod - def create_pt_shared_fn_insertion_command( - target_type: TargetType, - priority: TransformationPriority, - compression_module_type: ExtraCompressionModuleType, - fn=None, - group: str = "default_group", - op_unique_name: str = "UNIQUE_NAME", - ): - target_points = [] - - for node_name in TwoConvTestModel.NNCF_CONV_NODES_NAMES: - target_points.append(PTTargetPoint(target_type=target_type, target_node_name=node_name, input_port_id=0)) - if fn is None: - fn = DummyOpWithState("DUMMY_STATE") - return PTSharedFnInsertionCommand( - target_points=target_points, - fn=fn, - compression_module_type=compression_module_type, - op_unique_name=op_unique_name, - priority=priority, - hooks_group_name=group, - ) - - AVAILABLE_TARGET_TYPES = ( - TargetType.OPERATION_WITH_WEIGHTS, - TargetType.OPERATOR_PRE_HOOK, - TargetType.OPERATOR_POST_HOOK, - TargetType.PRE_LAYER_OPERATION, - TargetType.POST_LAYER_OPERATION, - ) - - @staticmethod - def get_command_builders(): - return ( - TwoConvTestModel.create_pt_insertion_command, - functools.partial( - TwoConvTestModel.create_pt_shared_fn_insertion_command, - compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP, - ), - functools.partial( - TwoConvTestModel.create_pt_shared_fn_insertion_command, - compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER, - ), - ) - - COMMAND_TYPES = [PTInsertionCommand, PTSharedFnInsertionCommand, PTSharedFnInsertionCommand] - PRIORITIES = (TransformationPriority.QUANTIZATION_PRIORITY, TransformationPriority.QUANTIZATION_PRIORITY.value + 1) - - @classmethod - def get_all_available_commands( - cls, dummy_op_state, skip_model_transformer_unsupported=False - ) -> TransformationLayout: - """ - Returns all possible commands to insert: - all target types x all command class x all compression module types x different priorities. - """ - layout = TransformationLayout() - for idx, (target_type, (command_builder, command_type), priority) in enumerate( - itertools.product( - cls.AVAILABLE_TARGET_TYPES, zip(cls.get_command_builders(), cls.COMMAND_TYPES), cls.PRIORITIES - ) - ): - if command_type is PTSharedFnInsertionCommand: - if skip_model_transformer_unsupported and target_type in [ - TargetType.PRE_LAYER_OPERATION, - TargetType.POST_LAYER_OPERATION, - ]: - continue - command = cls._create_command( - command_builder, target_type, priority, dummy_op_state, op_unique_name=f"UNIQUE_NAME_{idx}" - ) - else: - command = cls._create_command(command_builder, target_type, priority, dummy_op_state) - - layout.register(command) - return layout - - @staticmethod - def _create_command(command_builder, target_type, priority, dummy_op_state, op_unique_name=None): - group_name = "CUSTOM_HOOKS_GROUP_NAME" - - if DummyOpWithState.__name__ not in COMPRESSION_MODULES.registry_dict: - registered_dummy_op_cls = COMPRESSION_MODULES.register()(DummyOpWithState) - else: - registered_dummy_op_cls = DummyOpWithState - dummy_op = registered_dummy_op_cls(dummy_op_state) - if op_unique_name is None: - command = command_builder(target_type, priority, fn=dummy_op, group=group_name) - else: - command = command_builder( - target_type, priority, fn=dummy_op, group=group_name, op_unique_name=op_unique_name - ) - - return command - class LeNet(nn.Module): INPUT_SIZE = 1, 32, 32 @@ -372,38 +255,61 @@ def from_state(cls, state: str): return cls(state) -def target_points_are_equal(tp_original: PTTargetPoint, tp_recovered: PTTargetPoint): - if tp_original != tp_recovered: +def target_points_are_equal(tp_left: PTTargetPoint, tp_right: PTTargetPoint) -> bool: + """ + Returns True if given target points are equal and False elsewhere. + + :param tp_left: The first target point. + :param tp_right: The second target point. + :return: True if given target points are equal and False elsewhere. + """ + if tp_left != tp_right: return False - if tp_original.target_type == TargetType.OPERATOR_PRE_HOOK: - return tp_original.input_port_id == tp_recovered.input_port_id + if tp_left.target_type == TargetType.OPERATOR_PRE_HOOK: + return tp_left.input_port_id == tp_right.input_port_id return True -def are_commands_equal( - command, applied_command, check_priority: bool = True, check_hooks_group_name: bool = True, check_fn_ref=True -): - if type(applied_command) is not type(command): +def commands_are_equal( + command_left: Union[PTInsertionCommand, PTSharedFnInsertionCommand], + command_right: Union[PTInsertionCommand, PTSharedFnInsertionCommand], + check_priority: bool = True, + check_hooks_group_name: bool = True, + check_fn_ref=True, +) -> bool: + """ + Returns True if given commands are equal and False elsewhere. + + :param command_left: The first command. + :param command_right: The second command. + :param check_priority: Whether to check insertion priority or not. + :param check_hooks_group_name: Whether to check hooks group name or not. + :param check_fn_ref: Whether to check fn by reference or not. + :returns: True if given commands are equal and False elsewhere. + """ + if type(command_right) is not type(command_left): return False # Check reference to functions are equal. - if check_fn_ref and applied_command.fn is not command.fn: + if check_fn_ref and command_right.fn is not command_left.fn: return False - if check_hooks_group_name and applied_command.hooks_group_name != command.hooks_group_name: + if check_hooks_group_name and command_right.hooks_group_name != command_left.hooks_group_name: return False - if check_priority and applied_command.priority != command.priority: + if check_priority and command_right.priority != command_left.priority: return False - if isinstance(applied_command, PTInsertionCommand): - if not target_points_are_equal(command.target_point, applied_command.target_point): + if isinstance(command_right, PTInsertionCommand): + if not target_points_are_equal(command_left.target_point, command_right.target_point): return False - elif isinstance(applied_command, PTSharedFnInsertionCommand): - if not all(target_points_are_equal(a, b) for a, b in zip(command.target_points, applied_command.target_points)): + elif isinstance(command_right, PTSharedFnInsertionCommand): + if not all( + target_points_are_equal(a, b) for a, b in zip(command_left.target_points, command_right.target_points) + ): return False if ( - applied_command.target_points != command.target_points - or applied_command.op_name != command.op_name - or applied_command.compression_module_type != command.compression_module_type + command_right.target_points != command_left.target_points + or command_right.op_name != command_left.op_name + or command_right.compression_module_type != command_left.compression_module_type ): return False else: diff --git a/tests/torch/nncf_network/helpers.py b/tests/torch/nncf_network/helpers.py new file mode 100644 index 00000000000..a16dba2b263 --- /dev/null +++ b/tests/torch/nncf_network/helpers.py @@ -0,0 +1,194 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import itertools + +import torch + +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.commands import TransformationPriority +from nncf.common.graph.transformations.layout import TransformationLayout +from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType +from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.layer_utils import COMPRESSION_MODULES +from tests.torch.helpers import DummyOpWithState + + +class SimplestModel(torch.nn.Module): + INPUT_SIZE = [1, 1, 32, 32] + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + return self.conv(x) + + +AVAILABLE_TARGET_TYPES = ( + TargetType.OPERATION_WITH_WEIGHTS, + TargetType.OPERATOR_PRE_HOOK, + TargetType.OPERATOR_POST_HOOK, + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, +) + + +class InsertionCommandBuilder: + """ + Contains methods which allows to build all possible commands + for the TwoConvTestModel + """ + + NNCF_CONV_NODES_NAMES = [ + "TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0", + "TwoConvTestModel/Sequential[features]/Sequential[1]/NNCFConv2d[0]/conv2d_0", + ] + CONV_NODES_NAMES = [ + "TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0", + "TwoConvTestModel/Sequential[features]/Sequential[1]/Conv2d[0]/conv2d_0", + ] + + TRACE_VS_NODE_NAMES = {True: CONV_NODES_NAMES, False: NNCF_CONV_NODES_NAMES} + + @classmethod + def create_pt_insertion_command( + cls, + target_type: TargetType, + priority: TransformationPriority, + trace_parameters: bool, + fn=None, + group: str = "default_group", + ): + target_point = PTTargetPoint( + target_type=target_type, target_node_name=cls.TRACE_VS_NODE_NAMES[trace_parameters][0], input_port_id=0 + ) + if fn is None: + fn = DummyOpWithState("DUMMY_STATE") + return PTInsertionCommand(point=target_point, fn=fn, priority=priority, hooks_group_name=group) + + @classmethod + def create_pt_shared_fn_insertion_command( + cls, + target_type: TargetType, + priority: TransformationPriority, + trace_parameters: bool, + compression_module_type: ExtraCompressionModuleType, + fn=None, + group: str = "default_group", + op_unique_name: str = "UNIQUE_NAME", + ): + target_points = [] + + for node_name in cls.TRACE_VS_NODE_NAMES[trace_parameters]: + target_points.append(PTTargetPoint(target_type=target_type, target_node_name=node_name, input_port_id=0)) + if fn is None: + fn = DummyOpWithState("DUMMY_STATE") + return PTSharedFnInsertionCommand( + target_points=target_points, + fn=fn, + compression_module_type=compression_module_type, + op_unique_name=op_unique_name, + priority=priority, + hooks_group_name=group, + ) + + @staticmethod + def get_command_builders(): + return ( + InsertionCommandBuilder.create_pt_insertion_command, + functools.partial( + InsertionCommandBuilder.create_pt_shared_fn_insertion_command, + compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP, + ), + functools.partial( + InsertionCommandBuilder.create_pt_shared_fn_insertion_command, + compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER, + ), + ) + + @classmethod + def get_command_builders_with_types(cls): + return tuple(zip(cls.get_command_builders(), cls.COMMAND_TYPES)) + + COMMAND_TYPES = [PTInsertionCommand, PTSharedFnInsertionCommand, PTSharedFnInsertionCommand] + PRIORITIES = (TransformationPriority.QUANTIZATION_PRIORITY, TransformationPriority.QUANTIZATION_PRIORITY.value + 1) + + @classmethod + def get_all_available_commands( + cls, dummy_op_state, trace_parameters, skip_model_transformer_unsupported=False + ) -> TransformationLayout: + """ + Returns all possible commands to insert: + all target types x all command class x all compression module types x different priorities. + """ + layout = TransformationLayout() + for idx, (target_type, (command_builder, command_type), priority) in enumerate( + itertools.product( + AVAILABLE_TARGET_TYPES, zip(cls.get_command_builders(), cls.COMMAND_TYPES), cls.PRIORITIES + ) + ): + if command_type is PTSharedFnInsertionCommand: + if skip_model_transformer_unsupported and target_type in [ + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ]: + continue + command = cls._create_command( + command_builder, + target_type, + priority, + dummy_op_state, + op_unique_name=f"UNIQUE_NAME_{idx}", + trace_parameters=trace_parameters, + ) + else: + command = cls._create_command( + command_builder, target_type, priority, dummy_op_state, trace_parameters=trace_parameters + ) + + layout.register(command) + return layout + + @staticmethod + def _create_command( + command_builder, + target_type, + priority, + dummy_op_state, + trace_parameters, + op_unique_name=None, + ): + group_name = "CUSTOM_HOOKS_GROUP_NAME" + + if DummyOpWithState.__name__ not in COMPRESSION_MODULES.registry_dict: + registered_dummy_op_cls = COMPRESSION_MODULES.register()(DummyOpWithState) + else: + registered_dummy_op_cls = DummyOpWithState + dummy_op = registered_dummy_op_cls(dummy_op_state) + if op_unique_name is None: + command = command_builder( + target_type, priority, fn=dummy_op, group=group_name, trace_parameters=trace_parameters + ) + else: + command = command_builder( + target_type, + priority, + fn=dummy_op, + group=group_name, + op_unique_name=op_unique_name, + trace_parameters=trace_parameters, + ) + + return command diff --git a/tests/torch/nncf_network/test_get_applied_modifications.py b/tests/torch/nncf_network/test_get_applied_modifications.py new file mode 100644 index 00000000000..d19840821c1 --- /dev/null +++ b/tests/torch/nncf_network/test_get_applied_modifications.py @@ -0,0 +1,162 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.commands import TransformationPriority +from nncf.torch import wrap_model +from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType +from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand +from nncf.torch.graph.transformations.layout import PTTransformationLayout +from nncf.torch.model_transformer import PTModelTransformer +from tests.torch.helpers import TwoConvTestModel +from tests.torch.helpers import commands_are_equal +from tests.torch.nncf_network.helpers import AVAILABLE_TARGET_TYPES +from tests.torch.nncf_network.helpers import InsertionCommandBuilder + +TARGET_TYPE_VS_TARGET_TYPE_DICT_FOR_NOT_REPLACED_MODULES = { + TargetType.PRE_LAYER_OPERATION: TargetType.OPERATOR_PRE_HOOK, + TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK, + TargetType.OPERATION_WITH_WEIGHTS: TargetType.OPERATOR_PRE_HOOK, + TargetType.OPERATOR_PRE_HOOK: TargetType.OPERATOR_PRE_HOOK, + TargetType.OPERATOR_POST_HOOK: TargetType.OPERATOR_POST_HOOK, +} + + +@pytest.fixture(name="trace_parameters", params=(True, False)) +def trace_parameters_fixture(request) -> bool: + return request.param + + +def _translate_target_types(trace_parameters, command): + """ + Translates target types in case trace_parameters is True + """ + if not trace_parameters: + return + if isinstance(command, PTInsertionCommand): + target_points = [command.target_point] + else: + target_points = command.target_points + + for target_point in target_points: + new_target_type = TARGET_TYPE_VS_TARGET_TYPE_DICT_FOR_NOT_REPLACED_MODULES[target_point.type] + target_point._target_type = new_target_type + target_point.target_type = new_target_type + + +@pytest.mark.parametrize("target_type", AVAILABLE_TARGET_TYPES) +@pytest.mark.parametrize("command_builder", InsertionCommandBuilder.get_command_builders()) +def test_get_applied_modification_commands(command_builder, target_type, trace_parameters): + command = command_builder(target_type, TransformationPriority.DEFAULT_PRIORITY, trace_parameters=trace_parameters) + if isinstance(command, PTSharedFnInsertionCommand) and target_type in [ + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ]: + pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") + + model = TwoConvTestModel() + nncf_model = wrap_model(model, torch.zeros([1, 1, 4, 4]), trace_parameters=trace_parameters) + model_transformer = PTModelTransformer(nncf_model) + + layout = PTTransformationLayout() + layout.register(command) + model_transformer.transform(layout) + + applied_commands = nncf_model.nncf.get_applied_transformation_layout() + + assert len(applied_commands.transformations) == 1 + applied_command = applied_commands.transformations[0] + _translate_target_types(trace_parameters, command) + assert commands_are_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) + + +@pytest.mark.parametrize("target_type", AVAILABLE_TARGET_TYPES) +@pytest.mark.parametrize("command_builder,command_type", InsertionCommandBuilder.get_command_builders_with_types()) +def test_priority_of_get_applied_modification_commands(command_builder, target_type, command_type, trace_parameters): + layout = PTTransformationLayout() + commands = dict() + for priority in (0, 3, 2, 4, 1): + if command_type is PTSharedFnInsertionCommand: + command = command_builder( + target_type, priority, op_unique_name=f"UNIQUE_NAME_{priority}", trace_parameters=trace_parameters + ) + else: + command = command_builder(target_type, priority, trace_parameters=trace_parameters) + layout.register(command) + commands[priority] = command + else: + if isinstance(command, PTSharedFnInsertionCommand) and target_type in [ + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ]: + pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") + + model = TwoConvTestModel() + nncf_model = wrap_model(model, torch.zeros([1, 1, 4, 4]), trace_parameters=trace_parameters) + model_tranformer = PTModelTransformer(nncf_model) + + model_tranformer.transform(layout) + + applied_commands = nncf_model.nncf.get_applied_transformation_layout() + assert len(applied_commands.transformations) == len(commands) + for applied_command in applied_commands.transformations: + command = commands[applied_command.priority] + _translate_target_types(trace_parameters, command) + assert commands_are_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) + + +def test_all_possible_combinations_of_commands_for_get_applied_commands(trace_parameters): + dummy_state = "DummyState" + commands = InsertionCommandBuilder.get_all_available_commands( + dummy_state, skip_model_transformer_unsupported=True, trace_parameters=trace_parameters + ) + + model = TwoConvTestModel() + nncf_model = wrap_model(model, torch.zeros([1, 1, 4, 4]), trace_parameters=trace_parameters) + model_tranformer = PTModelTransformer(nncf_model) + + model_tranformer.transform(commands) + + applied_commands = nncf_model.nncf.get_applied_transformation_layout() + assert len(applied_commands.transformations) == len(commands.transformations) + for command in commands.transformations: + _translate_target_types(trace_parameters, command) + eq_commands = ( + commands_are_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) + for applied_command in applied_commands.transformations + ) + if sum(map(int, eq_commands)) != 1: + raise RuntimeError(f"Command {command} has no pair in recovered commands") + + +@pytest.mark.parametrize("target_type", (TargetType.OPERATION_WITH_WEIGHTS, TargetType.OPERATOR_PRE_HOOK)) +def test_get_applied_modification_commands_broken_call_hook(target_type, trace_parameters): + model = TwoConvTestModel() + nncf_model = wrap_model(model, torch.zeros([1, 1, 4, 4]), trace_parameters=trace_parameters) + model_tranformer = PTModelTransformer(nncf_model) + + layout = PTTransformationLayout() + command = InsertionCommandBuilder.create_pt_shared_fn_insertion_command( + target_type=target_type, + priority=0, + compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP, + trace_parameters=trace_parameters, + ) + layout.register(command) + model_tranformer.transform(layout) + + nncf_model.nncf.external_op.clear() + with pytest.raises(AssertionError): + nncf_model.nncf.get_applied_transformation_layout() diff --git a/tests/torch/nncf_network/test_hook_handlers.py b/tests/torch/nncf_network/test_hook_handlers.py new file mode 100644 index 00000000000..06c22d24a84 --- /dev/null +++ b/tests/torch/nncf_network/test_hook_handlers.py @@ -0,0 +1,119 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Tuple + +import pytest +import torch + +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.hook_handle import HookHandle +from nncf.torch.dynamic_graph.io_handling import ExampleInputInfo +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.nncf_network import PTInsertionPoint +from tests.torch.helpers import HookChecker +from tests.torch.nncf_network.helpers import SimplestModel + + +@pytest.mark.parametrize( + "target_type, target_node_name, input_port_id", + [ + (TargetType.OPERATOR_PRE_HOOK, "/nncf_model_output_0", 0), + (TargetType.OPERATOR_POST_HOOK, "/nncf_model_input_0", 0), + (TargetType.PRE_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0), + (TargetType.POST_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0), + ], +) +class TestHookHandles: + class TestHook(torch.nn.Module): + def __init__(self): + super().__init__() + self._p = torch.nn.Parameter(torch.zeros((1,))) + + def forward(self, x): + return x + self._p + + @staticmethod + def _prepare_hook_handles_test( + target_type: TargetType, target_node_name: str, input_port_id: int + ) -> Tuple[NNCFNetwork, PTInsertionPoint, Callable[[List[HookHandle]], None]]: + model = SimplestModel() + example_input = torch.ones(SimplestModel.INPUT_SIZE) + input_info = ExampleInputInfo.from_example_input(example_input) + nncf_model = NNCFNetwork(model, input_info) + + node_name_vs_address = nncf_model.nncf.get_node_to_op_address_mapping() + ip = PTInsertionPoint(target_type, node_name_vs_address[target_node_name], input_port_id=input_port_id) + + checker = HookChecker(nncf_model, "conv") + + def _check(ref_hooks_): + checker.clear() + checker.add_ref(ref_hooks_, target_type, target_node_name, input_port_id) + checker.check_with_reference() + + return nncf_model, ip, _check + + def test_temporary_insert_at_point_by_hook_group_name( + self, target_type: TargetType, target_node_name: str, input_port_id: int + ): + nncf_model, ip, _check = self._prepare_hook_handles_test(target_type, target_node_name, input_port_id) + permanent_hook = self.TestHook() + TEMPORARY_HOOK_GROUP_NAME = "tmp" + # Make temporary hook a ref to the permanent hook + # to check tmp hooks are not removed by their id() + temporary_hook = permanent_hook + nncf_model.nncf.insert_at_point(ip, permanent_hook) + ref_hooks = [permanent_hook] + _check(ref_hooks) + + for _ in range(2): + temporary_hook = self.TestHook() + nncf_model.nncf.insert_at_point(ip, temporary_hook, TEMPORARY_HOOK_GROUP_NAME) + ref_hooks.append(temporary_hook) + _check(ref_hooks) + + nncf_model.nncf.insert_at_point(ip, permanent_hook) + ref_hooks.append(permanent_hook) + _check(ref_hooks) + + nncf_model.nncf.remove_hooks_group(TEMPORARY_HOOK_GROUP_NAME) + del ref_hooks[-2] + _check(ref_hooks) + assert not nncf_model.nncf._groups_vs_hooks_handlers[TEMPORARY_HOOK_GROUP_NAME] + + def test_insert_at_point_hook_handles(self, target_type: TargetType, target_node_name: str, input_port_id: int): + nncf_model, ip, _check = self._prepare_hook_handles_test(target_type, target_node_name, input_port_id) + permanent_hook = self.TestHook() + # Make temporary hook a ref to the permanent hook + # to check tmp hooks are not removed by their id() + temporary_hook = permanent_hook + tmp_hh = [] + nncf_model.nncf.insert_at_point(ip, permanent_hook) + + ref_hooks = [permanent_hook] + _check(ref_hooks) + + for _ in range(2): + temporary_hook = self.TestHook() + tmp_hh.append(nncf_model.nncf.insert_at_point(ip, temporary_hook)) + ref_hooks.append(temporary_hook) + _check(ref_hooks) + + nncf_model.nncf.insert_at_point(ip, permanent_hook) + ref_hooks.append(permanent_hook) + _check(ref_hooks) + + for hh in tmp_hh: + hh.remove() + + del ref_hooks[-2] + _check(ref_hooks) diff --git a/tests/torch/test_nncf_network.py b/tests/torch/nncf_network/test_nncf_network.py similarity index 80% rename from tests/torch/test_nncf_network.py rename to tests/torch/nncf_network/test_nncf_network.py index 331d8c8e951..5fff987ea44 100644 --- a/tests/torch/test_nncf_network.py +++ b/tests/torch/nncf_network/test_nncf_network.py @@ -14,7 +14,7 @@ from abc import ABCMeta from abc import abstractmethod from copy import deepcopy -from typing import Callable, List, Tuple, Type +from typing import Callable, Type import pytest import torch @@ -27,8 +27,6 @@ from nncf.common.graph import NNCFNode from nncf.common.graph.operator_metatypes import UnknownMetatype from nncf.common.graph.transformations.commands import TargetType -from nncf.common.graph.transformations.commands import TransformationPriority -from nncf.common.hook_handle import HookHandle from nncf.torch import register_module from nncf.torch.dynamic_graph.io_handling import ExampleInputInfo from nncf.torch.dynamic_graph.io_handling import FillerInputElement @@ -41,12 +39,9 @@ from nncf.torch.graph.operator_metatypes import PTConv2dMetatype from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType -from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand -from nncf.torch.graph.transformations.layout import PTTransformationLayout from nncf.torch.layer_utils import _NNCFModuleMixin from nncf.torch.layers import NNCFConv2d from nncf.torch.model_creation import wrap_model -from nncf.torch.model_transformer import PTModelTransformer from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.nncf_network import PTInsertionPoint @@ -54,12 +49,11 @@ from nncf.torch.quantization.external_quantizer import EXTERNAL_QUANTIZERS_STORAGE_NAME from tests.torch.composite.test_sparsity_quantization import get_basic_sparsity_plus_quantization_config from tests.torch.helpers import BasicConvTestModel -from tests.torch.helpers import HookChecker from tests.torch.helpers import TwoConvTestModel -from tests.torch.helpers import are_commands_equal from tests.torch.helpers import check_correct_nncf_modules_replacement from tests.torch.helpers import create_compressed_model_and_algo_for_test from tests.torch.helpers import register_bn_adaptation_init_args +from tests.torch.nncf_network.helpers import SimplestModel from tests.torch.test_models.synthetic import ManyNonEvalModules @@ -618,17 +612,6 @@ def test_can_work_with_sequential_models(): _ = model.nncf.get_clean_shallow_copy() -class SimplestModel(torch.nn.Module): - INPUT_SIZE = [1, 1, 32, 32] - - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(1, 1, 1) - - def forward(self, x): - return self.conv(x) - - @pytest.fixture(name="simple_net") def simple_net_(): model = NNCFNetwork(SimplestModel(), FillerInputInfo([FillerInputElement(SimplestModel.INPUT_SIZE)])) @@ -933,179 +916,3 @@ def test_insert_hook_after_parameter(): assert hook.forward_calls_counter == 1 assert torch.sum(result.nonzero()) > 0 assert torch.sum(result_with_hook.nonzero()) == 0 - - -@pytest.mark.parametrize( - "target_type, target_node_name, input_port_id", - [ - (TargetType.OPERATOR_PRE_HOOK, "/nncf_model_output_0", 0), - (TargetType.OPERATOR_POST_HOOK, "/nncf_model_input_0", 0), - (TargetType.PRE_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0), - (TargetType.POST_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0), - ], -) -class TestHookHandles: - class TestHook(torch.nn.Module): - def __init__(self): - super().__init__() - self._p = torch.nn.Parameter(torch.zeros((1,))) - - def forward(self, x): - return x + self._p - - @staticmethod - def _prepare_hook_handles_test( - target_type: TargetType, target_node_name: str, input_port_id: int - ) -> Tuple[NNCFNetwork, PTInsertionPoint, Callable[[List[HookHandle]], None]]: - model = SimplestModel() - example_input = torch.ones(SimplestModel.INPUT_SIZE) - input_info = ExampleInputInfo.from_example_input(example_input) - nncf_model = NNCFNetwork(model, input_info) - - node_name_vs_address = nncf_model.nncf.get_node_to_op_address_mapping() - ip = PTInsertionPoint(target_type, node_name_vs_address[target_node_name], input_port_id=input_port_id) - - checker = HookChecker(nncf_model, "conv") - - def _check(ref_hooks_): - checker.clear() - checker.add_ref(ref_hooks_, target_type, target_node_name, input_port_id) - checker.check_with_reference() - - return nncf_model, ip, _check - - def test_temporary_insert_at_point_by_hook_group_name( - self, target_type: TargetType, target_node_name: str, input_port_id: int - ): - nncf_model, ip, _check = self._prepare_hook_handles_test(target_type, target_node_name, input_port_id) - permanent_hook = self.TestHook() - TEMPORARY_HOOK_GROUP_NAME = "tmp" - # Make temporary hook a ref to the permanent hook - # to check tmp hooks are not removed by their id() - temporary_hook = permanent_hook - nncf_model.nncf.insert_at_point(ip, permanent_hook) - ref_hooks = [permanent_hook] - _check(ref_hooks) - - for _ in range(2): - temporary_hook = self.TestHook() - nncf_model.nncf.insert_at_point(ip, temporary_hook, TEMPORARY_HOOK_GROUP_NAME) - ref_hooks.append(temporary_hook) - _check(ref_hooks) - - nncf_model.nncf.insert_at_point(ip, permanent_hook) - ref_hooks.append(permanent_hook) - _check(ref_hooks) - - nncf_model.nncf.remove_hooks_group(TEMPORARY_HOOK_GROUP_NAME) - del ref_hooks[-2] - _check(ref_hooks) - assert not nncf_model.nncf._groups_vs_hooks_handlers[TEMPORARY_HOOK_GROUP_NAME] - - def test_insert_at_point_hook_handles(self, target_type: TargetType, target_node_name: str, input_port_id: int): - nncf_model, ip, _check = self._prepare_hook_handles_test(target_type, target_node_name, input_port_id) - permanent_hook = self.TestHook() - # Make temporary hook a ref to the permanent hook - # to check tmp hooks are not removed by their id() - temporary_hook = permanent_hook - tmp_hh = [] - nncf_model.nncf.insert_at_point(ip, permanent_hook) - - ref_hooks = [permanent_hook] - _check(ref_hooks) - - for _ in range(2): - temporary_hook = self.TestHook() - tmp_hh.append(nncf_model.nncf.insert_at_point(ip, temporary_hook)) - ref_hooks.append(temporary_hook) - _check(ref_hooks) - - nncf_model.nncf.insert_at_point(ip, permanent_hook) - ref_hooks.append(permanent_hook) - _check(ref_hooks) - - for hh in tmp_hh: - hh.remove() - - del ref_hooks[-2] - _check(ref_hooks) - - -@pytest.mark.parametrize("target_type", TwoConvTestModel.AVAILABLE_TARGET_TYPES) -@pytest.mark.parametrize("command_builder", TwoConvTestModel.get_command_builders()) -def test_get_applied_modification_commands(command_builder, target_type): - command = command_builder(target_type, TransformationPriority.DEFAULT_PRIORITY) - if isinstance(command, PTSharedFnInsertionCommand) and target_type in [ - TargetType.PRE_LAYER_OPERATION, - TargetType.POST_LAYER_OPERATION, - ]: - pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") - - model = TwoConvTestModel() - nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) - model_tranformer = PTModelTransformer(nncf_model) - - layout = PTTransformationLayout() - layout.register(command) - model_tranformer.transform(layout) - - applied_commands = nncf_model.nncf.get_applied_transformation_layout() - - assert len(applied_commands.transformations) == 1 - applied_command = applied_commands.transformations[0] - are_commands_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) - - -@pytest.mark.parametrize("target_type", TwoConvTestModel.AVAILABLE_TARGET_TYPES) -@pytest.mark.parametrize( - "command_builder,command_type", tuple(zip(TwoConvTestModel.get_command_builders(), TwoConvTestModel.COMMAND_TYPES)) -) -def test_priority_of_get_applied_modification_commands(command_builder, target_type, command_type): - layout = PTTransformationLayout() - commands = dict() - for priority in (0, 3, 2, 4, 1): - if command_type is PTSharedFnInsertionCommand: - command = command_builder(target_type, priority, op_unique_name=f"UNIQUE_NAME_{priority}") - else: - command = command_builder(target_type, priority) - layout.register(command) - commands[priority] = command - else: - if isinstance(command, PTSharedFnInsertionCommand) and target_type in [ - TargetType.PRE_LAYER_OPERATION, - TargetType.POST_LAYER_OPERATION, - ]: - pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") - - model = TwoConvTestModel() - nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) - model_tranformer = PTModelTransformer(nncf_model) - - model_tranformer.transform(layout) - - applied_commands = nncf_model.nncf.get_applied_transformation_layout() - assert len(applied_commands.transformations) == len(commands) - for applied_command in applied_commands.transformations: - command = commands[applied_command.priority] - are_commands_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) - - -def test_all_possible_combinations_of_commands_for_get_applied_commands(): - dummy_state = "DummyState" - commands = TwoConvTestModel.get_all_available_commands(dummy_state, skip_model_transformer_unsupported=True) - - model = TwoConvTestModel() - nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) - model_tranformer = PTModelTransformer(nncf_model) - - model_tranformer.transform(commands) - - applied_commands = nncf_model.nncf.get_applied_transformation_layout() - assert len(applied_commands.transformations) == len(commands.transformations) - for command in commands.transformations: - eq_commands = ( - are_commands_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) - for applied_command in applied_commands.transformations - ) - if sum(map(int, eq_commands)) != 1: - raise RuntimeError(f"Command {command} has no pair in recovered commands") diff --git a/tests/torch/test_api_behavior.py b/tests/torch/test_api_behavior.py index 9e18479e94e..565eaa7e86b 100644 --- a/tests/torch/test_api_behavior.py +++ b/tests/torch/test_api_behavior.py @@ -24,7 +24,7 @@ from tests.torch.helpers import OnesDatasetMock from tests.torch.helpers import TwoConvTestModel from tests.torch.helpers import create_compressed_model_and_algo_for_test -from tests.torch.test_nncf_network import SimplestModel +from tests.torch.nncf_network.helpers import SimplestModel INPUT_SAMPLE_SIZE = [1, 1, 4, 4] CONFIG_WITH_ALL_INIT_TYPES = { From 8a8645eab501d64d34545b4b627ef92b15ddf29e Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Fri, 12 Apr 2024 14:42:11 +0200 Subject: [PATCH 4/9] Shared operations case handling --- nncf/torch/graph/graph.py | 20 ++- nncf/torch/nncf_network.py | 8 +- tests/torch/helpers.py | 32 +++++ tests/torch/nncf_network/helpers.py | 122 ++++++++---------- .../test_get_applied_modifications.py | 102 +++++++++------ 5 files changed, 178 insertions(+), 106 deletions(-) diff --git a/nncf/torch/graph/graph.py b/nncf/torch/graph/graph.py index 588866367c0..3a43d51d132 100644 --- a/nncf/torch/graph/graph.py +++ b/nncf/torch/graph/graph.py @@ -53,6 +53,12 @@ def get_input_shape_for_insertion_point(self, insertion_point: PTTargetPoint) -> return quantizer_input_shape def get_op_nodes_in_scope(self, scope: Scope) -> List[NNCFNode]: + """ + Returns all NNCFNodes inside the given scope. + + :param scope: Given scope. + :return: All NNCFNodes inside the given scope. + """ matching_graph_op_nodes = [] for scope_str, nodes_in_module in self._layer_name_vs_shared_nodes.items(): module_scope = Scope.from_str(scope_str) @@ -60,10 +66,22 @@ def get_op_nodes_in_scope(self, scope: Scope) -> List[NNCFNode]: matching_graph_op_nodes.extend(nodes_in_module) return matching_graph_op_nodes - def get_op_node_in_scope(self, scope: Scope) -> List[NNCFNode]: + def get_op_nodes_with_scope(self, scope: Scope) -> List[NNCFNode]: + """ + Returns all NNCFNodes which share the given scope. + + :param scope: Given scope. + :return: All NNCFNodes which share the given scope. + """ return self._layer_name_vs_shared_nodes[str(scope)] def get_scope_by_node_name(self, node_name: NNCFNodeName) -> Scope: + """ + Returns a scope which corresponds to the given NNCF node name. + + :param node_name: Given node name. + :return: A scope which corresponds to the given NNCF node name. + """ matches = [] for node_id, scope_str in self._node_ids_vs_layer_names.items(): node = self.get_node_by_id(node_id) diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index 01f3201e1e7..81d24b778ee 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -820,8 +820,11 @@ def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str): (nncf_module.post_ops, TargetType.POST_LAYER_OPERATION), ): for priority, module in enumerate(ops.values()): - nodes_in_scope = nncf_graph.get_op_node_in_scope(module_scope) - assert len(nodes_in_scope) == 1 + nodes_in_scope = nncf_graph.get_op_nodes_with_scope(module_scope) + # Several NNCFNodes means that current NNCFModule was called + # several times. Only one insertion command is required to + # call hook as much times as the current NNCFModule, therefore + # we use first correspondent NNCFNode. nncf_node = nodes_in_scope[0] command_target_type = target_type if isinstance(module, UpdateWeight): @@ -857,6 +860,7 @@ def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str): input_port_id = None for priority, fn in enumerate(hooks.values()): target_node_names = nncf_node_names_map[op_address] + # Operation address is unique for each module call assert len(target_node_names) == 1 target_node_name = target_node_names[0] diff --git a/tests/torch/helpers.py b/tests/torch/helpers.py index 4c40f218163..c6d2c92625b 100644 --- a/tests/torch/helpers.py +++ b/tests/torch/helpers.py @@ -177,6 +177,14 @@ def nz_bias_num(self): class TwoConvTestModel(nn.Module): INPUT_SHAPE = [1, 1, 4, 4] + NNCF_CONV_NODES_NAMES = [ + "TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0", + "TwoConvTestModel/Sequential[features]/Sequential[1]/NNCFConv2d[0]/conv2d_0", + ] + CONV_NODES_NAMES = [ + "TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0", + "TwoConvTestModel/Sequential[features]/Sequential[1]/Conv2d[0]/conv2d_0", + ] def __init__(self): super().__init__() @@ -205,6 +213,30 @@ def nz_bias_num(self): return 2 +class TwoSharedConvTestModel(nn.Module): + INPUT_SHAPE = [1, 1, 4, 4] + NNCF_CONV_NODES_NAMES = [ + "TwoSharedConvTestModel/NNCFConv2d[conv1]/conv2d_0", + "TwoSharedConvTestModel/NNCFConv2d[conv2]/conv2d_0", + ] + CONV_NODES_NAMES = [ + "TwoSharedConvTestModel/Conv2d[conv1]/conv2d_0", + "TwoSharedConvTestModel/Conv2d[conv2]/conv2d_0", + ] + + def __init__(self): + super().__init__() + self.features = [] + self.conv1 = create_conv(1, 1, 1, -1, -2) + self.conv2 = create_conv(1, 1, 1, 0, 0) + + def forward(self, x): + for _ in range(2): + x = self.conv1(x) + x = self.conv2(x) + return x + + class LeNet(nn.Module): INPUT_SIZE = 1, 32, 32 diff --git a/tests/torch/nncf_network/helpers.py b/tests/torch/nncf_network/helpers.py index a16dba2b263..e690dc75dd7 100644 --- a/tests/torch/nncf_network/helpers.py +++ b/tests/torch/nncf_network/helpers.py @@ -11,6 +11,7 @@ import functools import itertools +from typing import Optional, Type import torch @@ -48,50 +49,46 @@ def forward(self, x): class InsertionCommandBuilder: """ Contains methods which allows to build all possible commands - for the TwoConvTestModel + for the given torch.nn.Module. Target module should have + NNCF_CONV_NODES_NAMES and CONV_NODES_NAMES with names of + target model convolutions and names of nncf-wrapped target model convolutions """ - NNCF_CONV_NODES_NAMES = [ - "TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0", - "TwoConvTestModel/Sequential[features]/Sequential[1]/NNCFConv2d[0]/conv2d_0", - ] - CONV_NODES_NAMES = [ - "TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0", - "TwoConvTestModel/Sequential[features]/Sequential[1]/Conv2d[0]/conv2d_0", - ] + def __init__(self, model_cls: Type[torch.nn.Module]): + self.model_cls = model_cls - TRACE_VS_NODE_NAMES = {True: CONV_NODES_NAMES, False: NNCF_CONV_NODES_NAMES} + TRACE_VS_NODE_NAMES = {True: "CONV_NODES_NAMES", False: "NNCF_CONV_NODES_NAMES"} - @classmethod def create_pt_insertion_command( - cls, + self, target_type: TargetType, priority: TransformationPriority, trace_parameters: bool, - fn=None, + fn: Optional[torch.nn.Module] = None, group: str = "default_group", + op_unique_name: Optional[str] = None, ): + attr_name = self.TRACE_VS_NODE_NAMES[trace_parameters] target_point = PTTargetPoint( - target_type=target_type, target_node_name=cls.TRACE_VS_NODE_NAMES[trace_parameters][0], input_port_id=0 + target_type=target_type, target_node_name=getattr(self.model_cls, attr_name)[0], input_port_id=0 ) if fn is None: fn = DummyOpWithState("DUMMY_STATE") return PTInsertionCommand(point=target_point, fn=fn, priority=priority, hooks_group_name=group) - @classmethod def create_pt_shared_fn_insertion_command( - cls, + self, target_type: TargetType, priority: TransformationPriority, trace_parameters: bool, compression_module_type: ExtraCompressionModuleType, - fn=None, + fn: Optional[torch.nn.Module] = None, group: str = "default_group", op_unique_name: str = "UNIQUE_NAME", ): target_points = [] - - for node_name in cls.TRACE_VS_NODE_NAMES[trace_parameters]: + attr_name = self.TRACE_VS_NODE_NAMES[trace_parameters] + for node_name in getattr(self.model_cls, attr_name): target_points.append(PTTargetPoint(target_type=target_type, target_node_name=node_name, input_port_id=0)) if fn is None: fn = DummyOpWithState("DUMMY_STATE") @@ -104,30 +101,27 @@ def create_pt_shared_fn_insertion_command( hooks_group_name=group, ) - @staticmethod - def get_command_builders(): + def get_command_builders(self): + """ + Get all command builders available in a tuple. + """ return ( - InsertionCommandBuilder.create_pt_insertion_command, + self.create_pt_insertion_command, functools.partial( - InsertionCommandBuilder.create_pt_shared_fn_insertion_command, + self.create_pt_shared_fn_insertion_command, compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP, ), functools.partial( - InsertionCommandBuilder.create_pt_shared_fn_insertion_command, + self.create_pt_shared_fn_insertion_command, compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER, ), ) - @classmethod - def get_command_builders_with_types(cls): - return tuple(zip(cls.get_command_builders(), cls.COMMAND_TYPES)) - - COMMAND_TYPES = [PTInsertionCommand, PTSharedFnInsertionCommand, PTSharedFnInsertionCommand] + COMMAND_CLASSES = [PTInsertionCommand, PTSharedFnInsertionCommand, PTSharedFnInsertionCommand] PRIORITIES = (TransformationPriority.QUANTIZATION_PRIORITY, TransformationPriority.QUANTIZATION_PRIORITY.value + 1) - @classmethod def get_all_available_commands( - cls, dummy_op_state, trace_parameters, skip_model_transformer_unsupported=False + self, dummy_op_state, trace_parameters, skip_model_transformer_unsupported=False ) -> TransformationLayout: """ Returns all possible commands to insert: @@ -136,27 +130,27 @@ def get_all_available_commands( layout = TransformationLayout() for idx, (target_type, (command_builder, command_type), priority) in enumerate( itertools.product( - AVAILABLE_TARGET_TYPES, zip(cls.get_command_builders(), cls.COMMAND_TYPES), cls.PRIORITIES + AVAILABLE_TARGET_TYPES, zip(self.get_command_builders(), self.COMMAND_CLASSES), self.PRIORITIES ) ): - if command_type is PTSharedFnInsertionCommand: - if skip_model_transformer_unsupported and target_type in [ + if ( + skip_model_transformer_unsupported + and command_type is PTSharedFnInsertionCommand + and target_type + in [ TargetType.PRE_LAYER_OPERATION, TargetType.POST_LAYER_OPERATION, - ]: - continue - command = cls._create_command( - command_builder, - target_type, - priority, - dummy_op_state, - op_unique_name=f"UNIQUE_NAME_{idx}", - trace_parameters=trace_parameters, - ) - else: - command = cls._create_command( - command_builder, target_type, priority, dummy_op_state, trace_parameters=trace_parameters - ) + ] + ): + continue + command = self._create_command( + command_builder, + target_type, + priority, + dummy_op_state, + op_unique_name=f"UNIQUE_NAME_{idx}", + trace_parameters=trace_parameters, + ) layout.register(command) return layout @@ -168,27 +162,25 @@ def _create_command( priority, dummy_op_state, trace_parameters, - op_unique_name=None, + op_unique_name, ): - group_name = "CUSTOM_HOOKS_GROUP_NAME" - + """ + Creates command with specified parameters and dummy op. + """ + # Register dummy op name in the COMPRESSION_MODULES if DummyOpWithState.__name__ not in COMPRESSION_MODULES.registry_dict: registered_dummy_op_cls = COMPRESSION_MODULES.register()(DummyOpWithState) else: registered_dummy_op_cls = DummyOpWithState dummy_op = registered_dummy_op_cls(dummy_op_state) - if op_unique_name is None: - command = command_builder( - target_type, priority, fn=dummy_op, group=group_name, trace_parameters=trace_parameters - ) - else: - command = command_builder( - target_type, - priority, - fn=dummy_op, - group=group_name, - op_unique_name=op_unique_name, - trace_parameters=trace_parameters, - ) - return command + # Build the command + group_name = "CUSTOM_HOOKS_GROUP_NAME" + return command_builder( + target_type, + priority, + fn=dummy_op, + group=group_name, + op_unique_name=op_unique_name, + trace_parameters=trace_parameters, + ) diff --git a/tests/torch/nncf_network/test_get_applied_modifications.py b/tests/torch/nncf_network/test_get_applied_modifications.py index d19840821c1..7971ddab22a 100644 --- a/tests/torch/nncf_network/test_get_applied_modifications.py +++ b/tests/torch/nncf_network/test_get_applied_modifications.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools + import pytest import torch @@ -21,6 +23,7 @@ from nncf.torch.graph.transformations.layout import PTTransformationLayout from nncf.torch.model_transformer import PTModelTransformer from tests.torch.helpers import TwoConvTestModel +from tests.torch.helpers import TwoSharedConvTestModel from tests.torch.helpers import commands_are_equal from tests.torch.nncf_network.helpers import AVAILABLE_TARGET_TYPES from tests.torch.nncf_network.helpers import InsertionCommandBuilder @@ -39,6 +42,38 @@ def trace_parameters_fixture(request) -> bool: return request.param +MODELS_TO_TEST = (TwoConvTestModel, TwoSharedConvTestModel) + + +def _get_trace_params_target_types_command_builders_and_models_cls(): + retval = [] + for ( + trace_parameters, + model_cls, + target_type, + ) in itertools.product( + (True, False), + MODELS_TO_TEST, + AVAILABLE_TARGET_TYPES, + ): + for command_builder, command_cls in zip( + InsertionCommandBuilder(model_cls).get_command_builders(), InsertionCommandBuilder.COMMAND_CLASSES + ): + if ( + not trace_parameters + and command_cls is PTSharedFnInsertionCommand + and target_type + in [ + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ] + ): + print(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") + continue + retval.append((trace_parameters, model_cls, target_type, command_builder)) + return retval + + def _translate_target_types(trace_parameters, command): """ Translates target types in case trace_parameters is True @@ -56,21 +91,17 @@ def _translate_target_types(trace_parameters, command): target_point.target_type = new_target_type -@pytest.mark.parametrize("target_type", AVAILABLE_TARGET_TYPES) -@pytest.mark.parametrize("command_builder", InsertionCommandBuilder.get_command_builders()) -def test_get_applied_modification_commands(command_builder, target_type, trace_parameters): - command = command_builder(target_type, TransformationPriority.DEFAULT_PRIORITY, trace_parameters=trace_parameters) - if isinstance(command, PTSharedFnInsertionCommand) and target_type in [ - TargetType.PRE_LAYER_OPERATION, - TargetType.POST_LAYER_OPERATION, - ]: - pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") - - model = TwoConvTestModel() - nncf_model = wrap_model(model, torch.zeros([1, 1, 4, 4]), trace_parameters=trace_parameters) +@pytest.mark.parametrize( + "trace_parameters_p,model_cls,target_type,command_builder", + _get_trace_params_target_types_command_builders_and_models_cls(), +) +def test_get_applied_modification_commands(model_cls, command_builder, target_type, trace_parameters_p): + model = model_cls() + nncf_model = wrap_model(model, torch.zeros([1, 1, 4, 4]), trace_parameters=trace_parameters_p) model_transformer = PTModelTransformer(nncf_model) layout = PTTransformationLayout() + command = command_builder(target_type, TransformationPriority.DEFAULT_PRIORITY, trace_parameters=trace_parameters_p) layout.register(command) model_transformer.transform(layout) @@ -78,33 +109,26 @@ def test_get_applied_modification_commands(command_builder, target_type, trace_p assert len(applied_commands.transformations) == 1 applied_command = applied_commands.transformations[0] - _translate_target_types(trace_parameters, command) + _translate_target_types(trace_parameters_p, command) assert commands_are_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) -@pytest.mark.parametrize("target_type", AVAILABLE_TARGET_TYPES) -@pytest.mark.parametrize("command_builder,command_type", InsertionCommandBuilder.get_command_builders_with_types()) -def test_priority_of_get_applied_modification_commands(command_builder, target_type, command_type, trace_parameters): +@pytest.mark.parametrize( + "trace_parameters_p,model_cls,target_type,command_builder", + _get_trace_params_target_types_command_builders_and_models_cls(), +) +def test_priority_of_get_applied_modification_commands(command_builder, model_cls, target_type, trace_parameters_p): layout = PTTransformationLayout() commands = dict() for priority in (0, 3, 2, 4, 1): - if command_type is PTSharedFnInsertionCommand: - command = command_builder( - target_type, priority, op_unique_name=f"UNIQUE_NAME_{priority}", trace_parameters=trace_parameters - ) - else: - command = command_builder(target_type, priority, trace_parameters=trace_parameters) + command = command_builder( + target_type, priority, op_unique_name=f"UNIQUE_NAME_{priority}", trace_parameters=trace_parameters_p + ) layout.register(command) commands[priority] = command - else: - if isinstance(command, PTSharedFnInsertionCommand) and target_type in [ - TargetType.PRE_LAYER_OPERATION, - TargetType.POST_LAYER_OPERATION, - ]: - pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") - model = TwoConvTestModel() - nncf_model = wrap_model(model, torch.zeros([1, 1, 4, 4]), trace_parameters=trace_parameters) + model = model_cls() + nncf_model = wrap_model(model, torch.zeros([1, 1, 4, 4]), trace_parameters=trace_parameters_p) model_tranformer = PTModelTransformer(nncf_model) model_tranformer.transform(layout) @@ -113,17 +137,18 @@ def test_priority_of_get_applied_modification_commands(command_builder, target_t assert len(applied_commands.transformations) == len(commands) for applied_command in applied_commands.transformations: command = commands[applied_command.priority] - _translate_target_types(trace_parameters, command) + _translate_target_types(trace_parameters_p, command) assert commands_are_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) -def test_all_possible_combinations_of_commands_for_get_applied_commands(trace_parameters): +@pytest.mark.parametrize("model_cls", MODELS_TO_TEST) +def test_all_possible_combinations_of_commands_for_get_applied_commands(model_cls, trace_parameters): dummy_state = "DummyState" - commands = InsertionCommandBuilder.get_all_available_commands( - dummy_state, skip_model_transformer_unsupported=True, trace_parameters=trace_parameters + commands = InsertionCommandBuilder(model_cls).get_all_available_commands( + dummy_state, skip_model_transformer_unsupported=not trace_parameters, trace_parameters=trace_parameters ) - model = TwoConvTestModel() + model = model_cls() nncf_model = wrap_model(model, torch.zeros([1, 1, 4, 4]), trace_parameters=trace_parameters) model_tranformer = PTModelTransformer(nncf_model) @@ -142,13 +167,14 @@ def test_all_possible_combinations_of_commands_for_get_applied_commands(trace_pa @pytest.mark.parametrize("target_type", (TargetType.OPERATION_WITH_WEIGHTS, TargetType.OPERATOR_PRE_HOOK)) -def test_get_applied_modification_commands_broken_call_hook(target_type, trace_parameters): - model = TwoConvTestModel() +@pytest.mark.parametrize("model_cls", MODELS_TO_TEST) +def test_get_applied_modification_commands_broken_call_hook(model_cls, target_type, trace_parameters): + model = model_cls() nncf_model = wrap_model(model, torch.zeros([1, 1, 4, 4]), trace_parameters=trace_parameters) model_tranformer = PTModelTransformer(nncf_model) layout = PTTransformationLayout() - command = InsertionCommandBuilder.create_pt_shared_fn_insertion_command( + command = InsertionCommandBuilder(model_cls).create_pt_shared_fn_insertion_command( target_type=target_type, priority=0, compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP, From 46eb120d6d3a4fd0626fbac3b38c8bd1d52174e8 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Mon, 15 Apr 2024 15:10:22 +0200 Subject: [PATCH 5/9] Comments --- nncf/torch/graph/transformations/commands.py | 6 +- nncf/torch/nncf_network.py | 2 +- tests/torch/nncf_network/helpers.py | 52 +++++++----- .../test_get_applied_modifications.py | 85 ++++++++++--------- 4 files changed, 80 insertions(+), 65 deletions(-) diff --git a/nncf/torch/graph/transformations/commands.py b/nncf/torch/graph/transformations/commands.py index b2461277a5f..1dd2647d584 100644 --- a/nncf/torch/graph/transformations/commands.py +++ b/nncf/torch/graph/transformations/commands.py @@ -10,7 +10,7 @@ # limitations under the License. from enum import Enum -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Union import torch @@ -139,7 +139,7 @@ def __init__( self, point: PTTargetPoint, fn: Callable, - priority: TransformationPriority = TransformationPriority.DEFAULT_PRIORITY, + priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY, hooks_group_name: str = DEFAULT_HOOKS_GROUP_NAME, ): super().__init__(TransformationType.INSERT, point) @@ -164,7 +164,7 @@ def __init__( fn: Callable, op_unique_name: str, compression_module_type: ExtraCompressionModuleType = ExtraCompressionModuleType.EXTERNAL_OP, - priority: TransformationPriority = TransformationPriority.DEFAULT_PRIORITY, + priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY, hooks_group_name: str = DEFAULT_HOOKS_GROUP_NAME, ): super().__init__(TransformationType.INSERT, None) diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index 81d24b778ee..5cd45dfc191 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -788,7 +788,7 @@ def get_applied_transformation_layout(self) -> PTTransformationLayout: """ Collects all hooks applied to the NNCFNetwork, converts them to insertion commands and returns in PTTransformationLayout format. Default hooks group name is used in - recovered commands, so hooks group names specified diring the model modification + recovered commands, so hooks group names specified during the model modification become outdated. :return: Transformation layout with all commands applied to the NNCFNetwork. diff --git a/tests/torch/nncf_network/helpers.py b/tests/torch/nncf_network/helpers.py index e690dc75dd7..c51897cb3d7 100644 --- a/tests/torch/nncf_network/helpers.py +++ b/tests/torch/nncf_network/helpers.py @@ -22,8 +22,11 @@ from nncf.torch.graph.transformations.commands import PTInsertionCommand from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.graph.transformations.commands import PTTransformationCommand from nncf.torch.layer_utils import COMPRESSION_MODULES from tests.torch.helpers import DummyOpWithState +from tests.torch.helpers import TwoConvTestModel +from tests.torch.helpers import TwoSharedConvTestModel class SimplestModel(torch.nn.Module): @@ -54,6 +57,8 @@ class InsertionCommandBuilder: target model convolutions and names of nncf-wrapped target model convolutions """ + AVAILABLE_MODELS = (TwoConvTestModel, TwoSharedConvTestModel) + def __init__(self, model_cls: Type[torch.nn.Module]): self.model_cls = model_cls @@ -103,21 +108,28 @@ def create_pt_shared_fn_insertion_command( def get_command_builders(self): """ - Get all command builders available in a tuple. + Get all command builders available and their types in a tuple of pairs. """ return ( - self.create_pt_insertion_command, - functools.partial( - self.create_pt_shared_fn_insertion_command, - compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP, + (self.create_pt_insertion_command, PTInsertionCommand), + ( + functools.partial( + self.create_pt_shared_fn_insertion_command, + compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP, + ), + PTSharedFnInsertionCommand, ), - functools.partial( - self.create_pt_shared_fn_insertion_command, - compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER, + ( + functools.partial( + self.create_pt_shared_fn_insertion_command, + compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER, + ), + PTSharedFnInsertionCommand, ), ) COMMAND_CLASSES = [PTInsertionCommand, PTSharedFnInsertionCommand, PTSharedFnInsertionCommand] + # Check priority as an enum member and as an int PRIORITIES = (TransformationPriority.QUANTIZATION_PRIORITY, TransformationPriority.QUANTIZATION_PRIORITY.value + 1) def get_all_available_commands( @@ -129,18 +141,10 @@ def get_all_available_commands( """ layout = TransformationLayout() for idx, (target_type, (command_builder, command_type), priority) in enumerate( - itertools.product( - AVAILABLE_TARGET_TYPES, zip(self.get_command_builders(), self.COMMAND_CLASSES), self.PRIORITIES - ) + itertools.product(AVAILABLE_TARGET_TYPES, self.get_command_builders(), self.PRIORITIES) ): - if ( - skip_model_transformer_unsupported - and command_type is PTSharedFnInsertionCommand - and target_type - in [ - TargetType.PRE_LAYER_OPERATION, - TargetType.POST_LAYER_OPERATION, - ] + if skip_model_transformer_unsupported and self.is_unsupported_by_transformer_command( + command_type, target_type ): continue command = self._create_command( @@ -155,6 +159,16 @@ def get_all_available_commands( layout.register(command) return layout + @staticmethod + def is_unsupported_by_transformer_command(command_type: PTTransformationCommand, target_type: TargetType) -> bool: + """ + Returns True if insertion parameters don't supported by the PTModelTransformer otherwise False. + """ + return command_type is PTSharedFnInsertionCommand and target_type in [ + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ] + @staticmethod def _create_command( command_builder, diff --git a/tests/torch/nncf_network/test_get_applied_modifications.py b/tests/torch/nncf_network/test_get_applied_modifications.py index 7971ddab22a..c41e81c00f9 100644 --- a/tests/torch/nncf_network/test_get_applied_modifications.py +++ b/tests/torch/nncf_network/test_get_applied_modifications.py @@ -10,6 +10,7 @@ # limitations under the License. import itertools +from typing import Tuple, Type, Union import pytest import torch @@ -22,8 +23,6 @@ from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.graph.transformations.layout import PTTransformationLayout from nncf.torch.model_transformer import PTModelTransformer -from tests.torch.helpers import TwoConvTestModel -from tests.torch.helpers import TwoSharedConvTestModel from tests.torch.helpers import commands_are_equal from tests.torch.nncf_network.helpers import AVAILABLE_TARGET_TYPES from tests.torch.nncf_network.helpers import InsertionCommandBuilder @@ -42,10 +41,12 @@ def trace_parameters_fixture(request) -> bool: return request.param -MODELS_TO_TEST = (TwoConvTestModel, TwoSharedConvTestModel) - - -def _get_trace_params_target_types_command_builders_and_models_cls(): +def _get_trace_params_target_types_command_builders_and_models_cls() -> ( + Tuple[bool, Type[torch.nn.Module], TargetType, callable] +): + """ + Returns list of all avaliable command builders + """ retval = [] for ( trace_parameters, @@ -53,20 +54,12 @@ def _get_trace_params_target_types_command_builders_and_models_cls(): target_type, ) in itertools.product( (True, False), - MODELS_TO_TEST, + InsertionCommandBuilder.AVAILABLE_MODELS, AVAILABLE_TARGET_TYPES, ): - for command_builder, command_cls in zip( - InsertionCommandBuilder(model_cls).get_command_builders(), InsertionCommandBuilder.COMMAND_CLASSES - ): - if ( - not trace_parameters - and command_cls is PTSharedFnInsertionCommand - and target_type - in [ - TargetType.PRE_LAYER_OPERATION, - TargetType.POST_LAYER_OPERATION, - ] + for command_builder, command_cls in InsertionCommandBuilder(model_cls).get_command_builders(): + if not trace_parameters and InsertionCommandBuilder.is_unsupported_by_transformer_command( + command_cls, target_type ): print(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") continue @@ -74,7 +67,7 @@ def _get_trace_params_target_types_command_builders_and_models_cls(): return retval -def _translate_target_types(trace_parameters, command): +def _translate_target_types(trace_parameters: bool, command: Union[PTInsertionCommand, PTSharedFnInsertionCommand]): """ Translates target types in case trace_parameters is True """ @@ -92,16 +85,18 @@ def _translate_target_types(trace_parameters, command): @pytest.mark.parametrize( - "trace_parameters_p,model_cls,target_type,command_builder", + "trace_parameters,model_cls,target_type,command_builder", _get_trace_params_target_types_command_builders_and_models_cls(), ) -def test_get_applied_modification_commands(model_cls, command_builder, target_type, trace_parameters_p): +def test_get_applied_modification_commands( + model_cls: Type[torch.nn.Module], command_builder: callable, target_type: TargetType, trace_parameters: bool +): model = model_cls() - nncf_model = wrap_model(model, torch.zeros([1, 1, 4, 4]), trace_parameters=trace_parameters_p) + nncf_model = wrap_model(model, torch.zeros(model_cls.INPUT_SHAPE), trace_parameters=trace_parameters) model_transformer = PTModelTransformer(nncf_model) layout = PTTransformationLayout() - command = command_builder(target_type, TransformationPriority.DEFAULT_PRIORITY, trace_parameters=trace_parameters_p) + command = command_builder(target_type, TransformationPriority.DEFAULT_PRIORITY, trace_parameters=trace_parameters) layout.register(command) model_transformer.transform(layout) @@ -109,50 +104,54 @@ def test_get_applied_modification_commands(model_cls, command_builder, target_ty assert len(applied_commands.transformations) == 1 applied_command = applied_commands.transformations[0] - _translate_target_types(trace_parameters_p, command) + _translate_target_types(trace_parameters, command) assert commands_are_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) @pytest.mark.parametrize( - "trace_parameters_p,model_cls,target_type,command_builder", + "trace_parameters,model_cls,target_type,command_builder", _get_trace_params_target_types_command_builders_and_models_cls(), ) -def test_priority_of_get_applied_modification_commands(command_builder, model_cls, target_type, trace_parameters_p): +def test_priority_of_get_applied_modification_commands( + command_builder: callable, model_cls: Type[torch.nn.Module], target_type: TargetType, trace_parameters: bool +): layout = PTTransformationLayout() commands = dict() - for priority in (0, 3, 2, 4, 1): + for priority in (0, 2, 1): command = command_builder( - target_type, priority, op_unique_name=f"UNIQUE_NAME_{priority}", trace_parameters=trace_parameters_p + target_type, priority, op_unique_name=f"UNIQUE_NAME_{priority}", trace_parameters=trace_parameters ) layout.register(command) commands[priority] = command model = model_cls() - nncf_model = wrap_model(model, torch.zeros([1, 1, 4, 4]), trace_parameters=trace_parameters_p) - model_tranformer = PTModelTransformer(nncf_model) + nncf_model = wrap_model(model, torch.zeros(model_cls.INPUT_SHAPE), trace_parameters=trace_parameters) + model_transformer = PTModelTransformer(nncf_model) - model_tranformer.transform(layout) + model_transformer.transform(layout) applied_commands = nncf_model.nncf.get_applied_transformation_layout() assert len(applied_commands.transformations) == len(commands) for applied_command in applied_commands.transformations: command = commands[applied_command.priority] - _translate_target_types(trace_parameters_p, command) + _translate_target_types(trace_parameters, command) assert commands_are_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) -@pytest.mark.parametrize("model_cls", MODELS_TO_TEST) -def test_all_possible_combinations_of_commands_for_get_applied_commands(model_cls, trace_parameters): +@pytest.mark.parametrize("model_cls", InsertionCommandBuilder.AVAILABLE_MODELS) +def test_all_possible_combinations_of_commands_for_get_applied_commands( + model_cls: Type[torch.nn.Module], trace_parameters: bool +): dummy_state = "DummyState" commands = InsertionCommandBuilder(model_cls).get_all_available_commands( dummy_state, skip_model_transformer_unsupported=not trace_parameters, trace_parameters=trace_parameters ) model = model_cls() - nncf_model = wrap_model(model, torch.zeros([1, 1, 4, 4]), trace_parameters=trace_parameters) - model_tranformer = PTModelTransformer(nncf_model) + nncf_model = wrap_model(model, torch.zeros(model_cls.INPUT_SHAPE), trace_parameters=trace_parameters) + model_transformer = PTModelTransformer(nncf_model) - model_tranformer.transform(commands) + model_transformer.transform(commands) applied_commands = nncf_model.nncf.get_applied_transformation_layout() assert len(applied_commands.transformations) == len(commands.transformations) @@ -167,11 +166,13 @@ def test_all_possible_combinations_of_commands_for_get_applied_commands(model_cl @pytest.mark.parametrize("target_type", (TargetType.OPERATION_WITH_WEIGHTS, TargetType.OPERATOR_PRE_HOOK)) -@pytest.mark.parametrize("model_cls", MODELS_TO_TEST) -def test_get_applied_modification_commands_broken_call_hook(model_cls, target_type, trace_parameters): +@pytest.mark.parametrize("model_cls", InsertionCommandBuilder.AVAILABLE_MODELS) +def test_get_applied_modification_commands_broken_call_hook( + model_cls: Type[torch.nn.Module], target_type: TargetType, trace_parameters: bool +): model = model_cls() - nncf_model = wrap_model(model, torch.zeros([1, 1, 4, 4]), trace_parameters=trace_parameters) - model_tranformer = PTModelTransformer(nncf_model) + nncf_model = wrap_model(model, torch.zeros(model_cls.INPUT_SHAPE), trace_parameters=trace_parameters) + model_transformer = PTModelTransformer(nncf_model) layout = PTTransformationLayout() command = InsertionCommandBuilder(model_cls).create_pt_shared_fn_insertion_command( @@ -181,7 +182,7 @@ def test_get_applied_modification_commands_broken_call_hook(model_cls, target_ty trace_parameters=trace_parameters, ) layout.register(command) - model_tranformer.transform(layout) + model_transformer.transform(layout) nncf_model.nncf.external_op.clear() with pytest.raises(AssertionError): From 9e62cbf2d2c8e429bdfeb61f32e351c65db64b73 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Mon, 15 Apr 2024 15:58:47 +0200 Subject: [PATCH 6/9] Input port id is corrected for the commands --- tests/torch/helpers.py | 22 ++-------------------- tests/torch/nncf_network/helpers.py | 20 ++++++++++++++++++-- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/tests/torch/helpers.py b/tests/torch/helpers.py index c6d2c92625b..43b0b69c60b 100644 --- a/tests/torch/helpers.py +++ b/tests/torch/helpers.py @@ -41,7 +41,6 @@ from nncf.torch.dynamic_graph.scope import Scope from nncf.torch.graph.transformations.commands import PTInsertionCommand from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand -from nncf.torch.graph.transformations.commands import PTTargetPoint from nncf.torch.initialization import PTInitializingDataLoader from nncf.torch.initialization import register_default_init_args from nncf.torch.layers import NNCF_MODULES_MAP @@ -287,21 +286,6 @@ def from_state(cls, state: str): return cls(state) -def target_points_are_equal(tp_left: PTTargetPoint, tp_right: PTTargetPoint) -> bool: - """ - Returns True if given target points are equal and False elsewhere. - - :param tp_left: The first target point. - :param tp_right: The second target point. - :return: True if given target points are equal and False elsewhere. - """ - if tp_left != tp_right: - return False - if tp_left.target_type == TargetType.OPERATOR_PRE_HOOK: - return tp_left.input_port_id == tp_right.input_port_id - return True - - def commands_are_equal( command_left: Union[PTInsertionCommand, PTSharedFnInsertionCommand], command_right: Union[PTInsertionCommand, PTSharedFnInsertionCommand], @@ -331,12 +315,10 @@ def commands_are_equal( return False if isinstance(command_right, PTInsertionCommand): - if not target_points_are_equal(command_left.target_point, command_right.target_point): + if command_left.target_point != command_right.target_point: return False elif isinstance(command_right, PTSharedFnInsertionCommand): - if not all( - target_points_are_equal(a, b) for a, b in zip(command_left.target_points, command_right.target_points) - ): + if not all(a == b for a, b in zip(command_left.target_points, command_right.target_points)): return False if ( command_right.target_points != command_left.target_points diff --git a/tests/torch/nncf_network/helpers.py b/tests/torch/nncf_network/helpers.py index c51897cb3d7..36a850e986d 100644 --- a/tests/torch/nncf_network/helpers.py +++ b/tests/torch/nncf_network/helpers.py @@ -64,6 +64,14 @@ def __init__(self, model_cls: Type[torch.nn.Module]): TRACE_VS_NODE_NAMES = {True: "CONV_NODES_NAMES", False: "NNCF_CONV_NODES_NAMES"} + @staticmethod + def get_input_port_id(target_type: TargetType, trace_parameters: bool) -> Optional[int]: + if target_type is TargetType.OPERATOR_PRE_HOOK: + return 0 + if trace_parameters and target_type in [TargetType.PRE_LAYER_OPERATION, TargetType.OPERATION_WITH_WEIGHTS]: + return 1 + return None + def create_pt_insertion_command( self, target_type: TargetType, @@ -75,7 +83,9 @@ def create_pt_insertion_command( ): attr_name = self.TRACE_VS_NODE_NAMES[trace_parameters] target_point = PTTargetPoint( - target_type=target_type, target_node_name=getattr(self.model_cls, attr_name)[0], input_port_id=0 + target_type=target_type, + target_node_name=getattr(self.model_cls, attr_name)[0], + input_port_id=self.get_input_port_id(target_type, trace_parameters), ) if fn is None: fn = DummyOpWithState("DUMMY_STATE") @@ -94,7 +104,13 @@ def create_pt_shared_fn_insertion_command( target_points = [] attr_name = self.TRACE_VS_NODE_NAMES[trace_parameters] for node_name in getattr(self.model_cls, attr_name): - target_points.append(PTTargetPoint(target_type=target_type, target_node_name=node_name, input_port_id=0)) + target_points.append( + PTTargetPoint( + target_type=target_type, + target_node_name=node_name, + input_port_id=self.get_input_port_id(target_type, trace_parameters), + ) + ) if fn is None: fn = DummyOpWithState("DUMMY_STATE") return PTSharedFnInsertionCommand( From a6abd0a2558685029634ecfad11a8162ad19e86f Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Fri, 19 Apr 2024 13:21:02 +0200 Subject: [PATCH 7/9] get_applied_transformation_layout -> transformation_layout --- nncf/torch/nncf_network.py | 2 +- .../torch/nncf_network/test_get_applied_modifications.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index 5cd45dfc191..5fbbac4ad99 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -784,7 +784,7 @@ def _collect_eval_op_scopes(self, model: nn.Module, dummy_forward_fn: Callable) result.append(scope_in_model) return result - def get_applied_transformation_layout(self) -> PTTransformationLayout: + def transformation_layout(self) -> PTTransformationLayout: """ Collects all hooks applied to the NNCFNetwork, converts them to insertion commands and returns in PTTransformationLayout format. Default hooks group name is used in diff --git a/tests/torch/nncf_network/test_get_applied_modifications.py b/tests/torch/nncf_network/test_get_applied_modifications.py index c41e81c00f9..075625148eb 100644 --- a/tests/torch/nncf_network/test_get_applied_modifications.py +++ b/tests/torch/nncf_network/test_get_applied_modifications.py @@ -100,7 +100,7 @@ def test_get_applied_modification_commands( layout.register(command) model_transformer.transform(layout) - applied_commands = nncf_model.nncf.get_applied_transformation_layout() + applied_commands = nncf_model.nncf.transformation_layout() assert len(applied_commands.transformations) == 1 applied_command = applied_commands.transformations[0] @@ -130,7 +130,7 @@ def test_priority_of_get_applied_modification_commands( model_transformer.transform(layout) - applied_commands = nncf_model.nncf.get_applied_transformation_layout() + applied_commands = nncf_model.nncf.transformation_layout() assert len(applied_commands.transformations) == len(commands) for applied_command in applied_commands.transformations: command = commands[applied_command.priority] @@ -153,7 +153,7 @@ def test_all_possible_combinations_of_commands_for_get_applied_commands( model_transformer.transform(commands) - applied_commands = nncf_model.nncf.get_applied_transformation_layout() + applied_commands = nncf_model.nncf.transformation_layout() assert len(applied_commands.transformations) == len(commands.transformations) for command in commands.transformations: _translate_target_types(trace_parameters, command) @@ -186,4 +186,4 @@ def test_get_applied_modification_commands_broken_call_hook( nncf_model.nncf.external_op.clear() with pytest.raises(AssertionError): - nncf_model.nncf.get_applied_transformation_layout() + nncf_model.nncf.transformation_layout() From a49fe3407047389a388f9b4022551c11e9ad1b8e Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Mon, 22 Apr 2024 14:39:27 +0200 Subject: [PATCH 8/9] Minor --- tests/torch/nncf_network/helpers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/torch/nncf_network/helpers.py b/tests/torch/nncf_network/helpers.py index 36a850e986d..06805cd59b7 100644 --- a/tests/torch/nncf_network/helpers.py +++ b/tests/torch/nncf_network/helpers.py @@ -144,7 +144,6 @@ def get_command_builders(self): ), ) - COMMAND_CLASSES = [PTInsertionCommand, PTSharedFnInsertionCommand, PTSharedFnInsertionCommand] # Check priority as an enum member and as an int PRIORITIES = (TransformationPriority.QUANTIZATION_PRIORITY, TransformationPriority.QUANTIZATION_PRIORITY.value + 1) From 6937ca04dac3eff7b4a9087846264a6488e91e55 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Wed, 24 Apr 2024 18:09:51 +0200 Subject: [PATCH 9/9] Comments --- .../graph/transformations/command_creation.py | 26 ++++++++++++++++++- nncf/torch/nncf_network.py | 25 +----------------- 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/nncf/torch/graph/transformations/command_creation.py b/nncf/torch/graph/transformations/command_creation.py index 6146803ae19..bb2bf59a122 100644 --- a/nncf/torch/graph/transformations/command_creation.py +++ b/nncf/torch/graph/transformations/command_creation.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import List, Optional, Union +import torch from torch import Tensor from nncf.common.graph.graph import NNCFNode @@ -82,3 +83,26 @@ def create_shared_quantizer_insertion_command( compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER, priority=TransformationPriority.QUANTIZATION_PRIORITY, ) + + +def create_pt_insertion_command( + module: torch.nn.Module, + target_type: TargetType, + target_node_name: str, + priority: int, + input_port_id: Optional[int], +) -> PTInsertionCommand: + """ + Creates a PTInsertionCommand. + + :param module: Torch module to insert. + :param target_type: Insertion command target type. + :param target_name: Insertion command target name. + :param priority: Insertion command priority. + :param input_port_id: Insertion command input port id. + :return: A PTInsertionCommand + """ + target_point = PTTargetPoint( + target_type=target_type, target_node_name=target_node_name, input_port_id=input_port_id + ) + return PTInsertionCommand(point=target_point, fn=module, priority=priority) diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index 5fbbac4ad99..0ca8a14380f 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -67,9 +67,9 @@ from nncf.torch.graph.graph_builder import GraphConverter from nncf.torch.graph.operator_metatypes import OPERATORS_WITH_WEIGHTS_METATYPES from nncf.torch.graph.operator_metatypes import PTSplitMetatype +from nncf.torch.graph.transformations.command_creation import create_pt_insertion_command from nncf.torch.graph.transformations.commands import DEFAULT_HOOKS_GROUP_NAME from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType -from nncf.torch.graph.transformations.commands import PTInsertionCommand from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint from nncf.torch.graph.transformations.layout import PTTransformationLayout @@ -1270,26 +1270,3 @@ def compression_module_type_to_attr_name(compression_module_type: ExtraCompressi if compression_module_type == ExtraCompressionModuleType.EXTERNAL_OP: return EXTERNAL_OP_STORAGE_NAME raise nncf.ValidationError("Unknown extra module type") - - -def create_pt_insertion_command( - module: torch.nn.Module, - target_type: TargetType, - target_node_name: str, - priority: int, - input_port_id: Optional[int], -) -> PTInsertionCommand: - """ - Creates a PTInsertionCommand. - - :param module: Torch module to insert. - :param target_type: Insertion command target type. - :param target_name: Insertion command target name. - :param priority: Insertion command priority. - :param input_port_id: Insertion command input port id. - :return: A PTInsertionCommand - """ - target_point = PTTargetPoint( - target_type=target_type, target_node_name=target_node_name, input_port_id=input_port_id - ) - return PTInsertionCommand(point=target_point, fn=module, priority=priority)