Skip to content

Commit

Permalink
Input port id is corrected for the commands
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Apr 22, 2024
1 parent 46eb120 commit 9e62cbf
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 22 deletions.
22 changes: 2 additions & 20 deletions tests/torch/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from nncf.torch.dynamic_graph.scope import Scope
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.initialization import PTInitializingDataLoader
from nncf.torch.initialization import register_default_init_args
from nncf.torch.layers import NNCF_MODULES_MAP
Expand Down Expand Up @@ -287,21 +286,6 @@ def from_state(cls, state: str):
return cls(state)


def target_points_are_equal(tp_left: PTTargetPoint, tp_right: PTTargetPoint) -> bool:
"""
Returns True if given target points are equal and False elsewhere.
:param tp_left: The first target point.
:param tp_right: The second target point.
:return: True if given target points are equal and False elsewhere.
"""
if tp_left != tp_right:
return False
if tp_left.target_type == TargetType.OPERATOR_PRE_HOOK:
return tp_left.input_port_id == tp_right.input_port_id
return True


def commands_are_equal(
command_left: Union[PTInsertionCommand, PTSharedFnInsertionCommand],
command_right: Union[PTInsertionCommand, PTSharedFnInsertionCommand],
Expand Down Expand Up @@ -331,12 +315,10 @@ def commands_are_equal(
return False

if isinstance(command_right, PTInsertionCommand):
if not target_points_are_equal(command_left.target_point, command_right.target_point):
if command_left.target_point != command_right.target_point:
return False
elif isinstance(command_right, PTSharedFnInsertionCommand):
if not all(
target_points_are_equal(a, b) for a, b in zip(command_left.target_points, command_right.target_points)
):
if not all(a == b for a, b in zip(command_left.target_points, command_right.target_points)):
return False
if (
command_right.target_points != command_left.target_points
Expand Down
20 changes: 18 additions & 2 deletions tests/torch/nncf_network/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ def __init__(self, model_cls: Type[torch.nn.Module]):

TRACE_VS_NODE_NAMES = {True: "CONV_NODES_NAMES", False: "NNCF_CONV_NODES_NAMES"}

@staticmethod
def get_input_port_id(target_type: TargetType, trace_parameters: bool) -> Optional[int]:
if target_type is TargetType.OPERATOR_PRE_HOOK:
return 0
if trace_parameters and target_type in [TargetType.PRE_LAYER_OPERATION, TargetType.OPERATION_WITH_WEIGHTS]:
return 1
return None

def create_pt_insertion_command(
self,
target_type: TargetType,
Expand All @@ -75,7 +83,9 @@ def create_pt_insertion_command(
):
attr_name = self.TRACE_VS_NODE_NAMES[trace_parameters]
target_point = PTTargetPoint(
target_type=target_type, target_node_name=getattr(self.model_cls, attr_name)[0], input_port_id=0
target_type=target_type,
target_node_name=getattr(self.model_cls, attr_name)[0],
input_port_id=self.get_input_port_id(target_type, trace_parameters),
)
if fn is None:
fn = DummyOpWithState("DUMMY_STATE")
Expand All @@ -94,7 +104,13 @@ def create_pt_shared_fn_insertion_command(
target_points = []
attr_name = self.TRACE_VS_NODE_NAMES[trace_parameters]
for node_name in getattr(self.model_cls, attr_name):
target_points.append(PTTargetPoint(target_type=target_type, target_node_name=node_name, input_port_id=0))
target_points.append(
PTTargetPoint(
target_type=target_type,
target_node_name=node_name,
input_port_id=self.get_input_port_id(target_type, trace_parameters),
)
)
if fn is None:
fn = DummyOpWithState("DUMMY_STATE")
return PTSharedFnInsertionCommand(
Expand Down

0 comments on commit 9e62cbf

Please sign in to comment.