From 9e62cbf2d2c8e429bdfeb61f32e351c65db64b73 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Mon, 15 Apr 2024 15:58:47 +0200 Subject: [PATCH] Input port id is corrected for the commands --- tests/torch/helpers.py | 22 ++-------------------- tests/torch/nncf_network/helpers.py | 20 ++++++++++++++++++-- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/tests/torch/helpers.py b/tests/torch/helpers.py index c6d2c92625b..43b0b69c60b 100644 --- a/tests/torch/helpers.py +++ b/tests/torch/helpers.py @@ -41,7 +41,6 @@ from nncf.torch.dynamic_graph.scope import Scope from nncf.torch.graph.transformations.commands import PTInsertionCommand from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand -from nncf.torch.graph.transformations.commands import PTTargetPoint from nncf.torch.initialization import PTInitializingDataLoader from nncf.torch.initialization import register_default_init_args from nncf.torch.layers import NNCF_MODULES_MAP @@ -287,21 +286,6 @@ def from_state(cls, state: str): return cls(state) -def target_points_are_equal(tp_left: PTTargetPoint, tp_right: PTTargetPoint) -> bool: - """ - Returns True if given target points are equal and False elsewhere. - - :param tp_left: The first target point. - :param tp_right: The second target point. - :return: True if given target points are equal and False elsewhere. - """ - if tp_left != tp_right: - return False - if tp_left.target_type == TargetType.OPERATOR_PRE_HOOK: - return tp_left.input_port_id == tp_right.input_port_id - return True - - def commands_are_equal( command_left: Union[PTInsertionCommand, PTSharedFnInsertionCommand], command_right: Union[PTInsertionCommand, PTSharedFnInsertionCommand], @@ -331,12 +315,10 @@ def commands_are_equal( return False if isinstance(command_right, PTInsertionCommand): - if not target_points_are_equal(command_left.target_point, command_right.target_point): + if command_left.target_point != command_right.target_point: return False elif isinstance(command_right, PTSharedFnInsertionCommand): - if not all( - target_points_are_equal(a, b) for a, b in zip(command_left.target_points, command_right.target_points) - ): + if not all(a == b for a, b in zip(command_left.target_points, command_right.target_points)): return False if ( command_right.target_points != command_left.target_points diff --git a/tests/torch/nncf_network/helpers.py b/tests/torch/nncf_network/helpers.py index c51897cb3d7..36a850e986d 100644 --- a/tests/torch/nncf_network/helpers.py +++ b/tests/torch/nncf_network/helpers.py @@ -64,6 +64,14 @@ def __init__(self, model_cls: Type[torch.nn.Module]): TRACE_VS_NODE_NAMES = {True: "CONV_NODES_NAMES", False: "NNCF_CONV_NODES_NAMES"} + @staticmethod + def get_input_port_id(target_type: TargetType, trace_parameters: bool) -> Optional[int]: + if target_type is TargetType.OPERATOR_PRE_HOOK: + return 0 + if trace_parameters and target_type in [TargetType.PRE_LAYER_OPERATION, TargetType.OPERATION_WITH_WEIGHTS]: + return 1 + return None + def create_pt_insertion_command( self, target_type: TargetType, @@ -75,7 +83,9 @@ def create_pt_insertion_command( ): attr_name = self.TRACE_VS_NODE_NAMES[trace_parameters] target_point = PTTargetPoint( - target_type=target_type, target_node_name=getattr(self.model_cls, attr_name)[0], input_port_id=0 + target_type=target_type, + target_node_name=getattr(self.model_cls, attr_name)[0], + input_port_id=self.get_input_port_id(target_type, trace_parameters), ) if fn is None: fn = DummyOpWithState("DUMMY_STATE") @@ -94,7 +104,13 @@ def create_pt_shared_fn_insertion_command( target_points = [] attr_name = self.TRACE_VS_NODE_NAMES[trace_parameters] for node_name in getattr(self.model_cls, attr_name): - target_points.append(PTTargetPoint(target_type=target_type, target_node_name=node_name, input_port_id=0)) + target_points.append( + PTTargetPoint( + target_type=target_type, + target_node_name=node_name, + input_port_id=self.get_input_port_id(target_type, trace_parameters), + ) + ) if fn is None: fn = DummyOpWithState("DUMMY_STATE") return PTSharedFnInsertionCommand(