Commit

daniil-lyakhov committed Apr 11, 2024
1 parent 0eca703 commit 15f5428
Showing 8 changed files with 558 additions and 347 deletions.
6 changes: 1 addition & 5 deletions nncf/torch/graph/graph.py
@@ -61,11 +61,7 @@ def get_op_nodes_in_scope(self, scope: Scope) -> List[NNCFNode]:
         return matching_graph_op_nodes
 
     def get_op_node_in_scope(self, scope: Scope) -> List[NNCFNode]:
-        for scope_str, nodes_in_module in self._layer_name_vs_shared_nodes.items():
-            module_scope = Scope.from_str(scope_str)
-            if module_scope == scope:
-                return nodes_in_module
-        return []
+        return self._layer_name_vs_shared_nodes[str(scope)]
 
     def get_scope_by_node_name(self, node_name: NNCFNodeName) -> Scope:
         matches = []
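The graph.py hunk replaces a linear scan, which parsed every stored scope string back into a Scope and compared it against the query, with a direct dictionary lookup keyed by str(scope). A minimal standalone sketch of the idea (not NNCF code: the storage name mirrors _layer_name_vs_shared_nodes, and a defaultdict is assumed so that a missing scope still yields an empty list rather than a KeyError):

from collections import defaultdict

# Hypothetical stand-in for NNCFGraph._layer_name_vs_shared_nodes.
layer_name_vs_shared_nodes = defaultdict(list)
layer_name_vs_shared_nodes["Model/NNCFConv2d[conv]"].append("conv2d_0")

def get_op_node_in_scope_old(scope: str) -> list:
    # Old behavior: O(n) scan over every stored scope string.
    for scope_str, nodes_in_module in layer_name_vs_shared_nodes.items():
        if scope_str == scope:
            return nodes_in_module
    return []

def get_op_node_in_scope_new(scope: str) -> list:
    # New behavior: a single O(1) lookup by the scope's string form.
    return layer_name_vs_shared_nodes[scope]

assert get_op_node_in_scope_new("Model/NNCFConv2d[conv]") == ["conv2d_0"]
assert get_op_node_in_scope_new("Model/Unknown[scope]") == []  # defaultdict miss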
49 changes: 38 additions & 11 deletions nncf/torch/nncf_network.py
@@ -794,19 +794,19 @@ def get_applied_transformation_layout(self) -> PTTransformationLayout:
         :return: Transformation layout with all commands applied to the NNCFNetwork.
         """
 
-        def _create_pt_insert_command(module, target_type, target_node_name, priority, input_port_id):
-            target_point = PTTargetPoint(
-                target_type=target_type, target_node_name=target_node_name, input_port_id=input_port_id
-            )
-            return PTInsertionCommand(point=target_point, fn=module, priority=priority)
-
         def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str):
+            """
+            Check given external op call hook reference is correct.
+
+            :param hook: External op call hook to check correctness.
+            :param info: Info to log in case op call hook references are broken.
+            """
             assert hasattr(
                 self, hook._storage_name
             ), f"Storage name {hook._storage_name} is not registered. Info: {info}"
             assert hook._storage_key in getattr(
                 self, hook._storage_name
-            ), f"Storage key {hook._storage_key} is not registered. Info: {info}"
+            ), f"Key {hook._storage_key} is not registered in {hook._storage_name}. Info: {info}"
 
         context_hooks = defaultdict(lambda: defaultdict(list))
         transformation_layout = PTTransformationLayout()
@@ -828,14 +828,16 @@ def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str):
                     command_target_type = TargetType.OPERATION_WITH_WEIGHTS
                     module = module.op
                 if not isinstance(module, ExternalOpCallHook):
-                    command = _create_pt_insert_command(
+                    command = create_pt_insertion_command(
                         module, command_target_type, nncf_node.node_name, priority, None
                     )
                     transformation_layout.register(command)
                     continue
 
-                info = f"TargetType: {command_target_type}, nncf node name: {nncf_node.node_name},"
-                f" priority: {priority}, fn: {module}"
+                info = (
+                    f"TargetType: {command_target_type}, nncf node name: {nncf_node.node_name},"
+                    f" priority: {priority}, fn: {module}"
+                )
                 _check_external_call_hook_is_valid(module, info)
 
                 context_hooks[module._storage_name][module._storage_key].append(
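A detail worth calling out in this hunk: the old two-line assignment to info was a real bug, not just a style issue. Without enclosing parentheses, only the first f-string is assigned; the second line is parsed as a separate expression statement and silently discarded. Wrapping both literals in parentheses makes Python concatenate them. A standalone demonstration:

name, priority = "conv2d_0", 11

# Buggy form: the second f-string is a no-op expression statement.
info = f"nncf node name: {name},"
f" priority: {priority}"  # evaluated and immediately discarded
assert info == "nncf node name: conv2d_0,"

# Fixed form: adjacent literals inside parentheses are concatenated.
info = (
    f"nncf node name: {name},"
    f" priority: {priority}"
)
assert info == "nncf node name: conv2d_0, priority: 11"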
@@ -859,7 +861,9 @@ def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str):
                 target_node_name = target_node_names[0]
 
             if not isinstance(fn, ExternalOpCallHook):
-                command = _create_pt_insert_command(fn, target_type, target_node_name, priority, input_port_id)
+                command = create_pt_insertion_command(
+                    fn, target_type, target_node_name, priority, input_port_id
+                )
                 transformation_layout.register(command)
                 continue
 
@@ -1262,3 +1266,26 @@ def compression_module_type_to_attr_name(compression_module_type: ExtraCompressi
     if compression_module_type == ExtraCompressionModuleType.EXTERNAL_OP:
         return EXTERNAL_OP_STORAGE_NAME
     raise nncf.ValidationError("Unknown extra module type")
+
+
+def create_pt_insertion_command(
+    module: torch.nn.Module,
+    target_type: TargetType,
+    target_node_name: str,
+    priority: int,
+    input_port_id: Optional[int],
+) -> PTInsertionCommand:
+    """
+    Creates a PTInsertionCommand.
+
+    :param module: Torch module to insert.
+    :param target_type: Insertion command target type.
+    :param target_node_name: Insertion command target node name.
+    :param priority: Insertion command priority.
+    :param input_port_id: Insertion command input port id.
+    :return: A PTInsertionCommand.
+    """
+    target_point = PTTargetPoint(
+        target_type=target_type, target_node_name=target_node_name, input_port_id=input_port_id
+    )
+    return PTInsertionCommand(point=target_point, fn=module, priority=priority)
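The nested closure _create_pt_insert_command is promoted to the module-level helper create_pt_insertion_command, so callers outside get_applied_transformation_layout can build insertion commands as well. A hedged usage sketch (the node name and priority are placeholders, not values from this commit):

import torch

from nncf.common.graph.transformations.commands import TargetType
from nncf.torch.nncf_network import create_pt_insertion_command

command = create_pt_insertion_command(
    module=torch.nn.Identity(),  # torch module to insert
    target_type=TargetType.OPERATOR_POST_HOOK,
    target_node_name="Model/NNCFConv2d[conv]/conv2d_0",  # placeholder node name
    priority=0,  # placeholder priority
    input_port_id=None,  # only meaningful for pre-hook insertion points
)
# The resulting PTInsertionCommand can be registered in a PTTransformationLayout.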
176 changes: 41 additions & 135 deletions tests/torch/helpers.py
@@ -10,8 +10,6 @@
 # limitations under the License.
 
 import contextlib
-import functools
-import itertools
 import numbers
 from abc import ABC
 from abc import abstractmethod
@@ -32,8 +30,6 @@
 
 import nncf
 from nncf.common.graph.transformations.commands import TargetType
-from nncf.common.graph.transformations.commands import TransformationPriority
-from nncf.common.graph.transformations.layout import TransformationLayout
 from nncf.config import NNCFConfig
 from nncf.config.extractors import extract_algorithm_names
 from nncf.config.structures import BNAdaptationInitArgs
@@ -43,13 +39,11 @@
 from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
 from nncf.torch.dynamic_graph.operation_address import OperationAddress
 from nncf.torch.dynamic_graph.scope import Scope
-from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
 from nncf.torch.graph.transformations.commands import PTInsertionCommand
 from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
 from nncf.torch.graph.transformations.commands import PTTargetPoint
 from nncf.torch.initialization import PTInitializingDataLoader
 from nncf.torch.initialization import register_default_init_args
-from nncf.torch.layer_utils import COMPRESSION_MODULES
 from nncf.torch.layers import NNCF_MODULES_MAP
 from nncf.torch.model_creation import create_compressed_model
 from nncf.torch.module_operations import UpdateWeight
@@ -183,10 +177,6 @@ def nz_bias_num(self):
 
 class TwoConvTestModel(nn.Module):
     INPUT_SHAPE = [1, 1, 4, 4]
-    NNCF_CONV_NODES_NAMES = [
-        "TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0",
-        "TwoConvTestModel/Sequential[features]/Sequential[1]/NNCFConv2d[0]/conv2d_0",
-    ]
 
     def __init__(self):
         super().__init__()
@@ -214,113 +204,6 @@ def nz_weights_num(self):
     def nz_bias_num(self):
         return 2
 
-    @staticmethod
-    def create_pt_insertion_command(
-        target_type: TargetType, priority: TransformationPriority, fn=None, group: str = "default_group"
-    ):
-        target_point = PTTargetPoint(
-            target_type=target_type, target_node_name=TwoConvTestModel.NNCF_CONV_NODES_NAMES[0], input_port_id=0
-        )
-        if fn is None:
-            fn = DummyOpWithState("DUMMY_STATE")
-        return PTInsertionCommand(point=target_point, fn=fn, priority=priority, hooks_group_name=group)
-
-    @staticmethod
-    def create_pt_shared_fn_insertion_command(
-        target_type: TargetType,
-        priority: TransformationPriority,
-        compression_module_type: ExtraCompressionModuleType,
-        fn=None,
-        group: str = "default_group",
-        op_unique_name: str = "UNIQUE_NAME",
-    ):
-        target_points = []
-
-        for node_name in TwoConvTestModel.NNCF_CONV_NODES_NAMES:
-            target_points.append(PTTargetPoint(target_type=target_type, target_node_name=node_name, input_port_id=0))
-        if fn is None:
-            fn = DummyOpWithState("DUMMY_STATE")
-        return PTSharedFnInsertionCommand(
-            target_points=target_points,
-            fn=fn,
-            compression_module_type=compression_module_type,
-            op_unique_name=op_unique_name,
-            priority=priority,
-            hooks_group_name=group,
-        )
-
-    AVAILABLE_TARGET_TYPES = (
-        TargetType.OPERATION_WITH_WEIGHTS,
-        TargetType.OPERATOR_PRE_HOOK,
-        TargetType.OPERATOR_POST_HOOK,
-        TargetType.PRE_LAYER_OPERATION,
-        TargetType.POST_LAYER_OPERATION,
-    )
-
-    @staticmethod
-    def get_command_builders():
-        return (
-            TwoConvTestModel.create_pt_insertion_command,
-            functools.partial(
-                TwoConvTestModel.create_pt_shared_fn_insertion_command,
-                compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP,
-            ),
-            functools.partial(
-                TwoConvTestModel.create_pt_shared_fn_insertion_command,
-                compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
-            ),
-        )
-
-    COMMAND_TYPES = [PTInsertionCommand, PTSharedFnInsertionCommand, PTSharedFnInsertionCommand]
-    PRIORITIES = (TransformationPriority.QUANTIZATION_PRIORITY, TransformationPriority.QUANTIZATION_PRIORITY.value + 1)
-
-    @classmethod
-    def get_all_available_commands(
-        cls, dummy_op_state, skip_model_transformer_unsupported=False
-    ) -> TransformationLayout:
-        """
-        Returns all possible commands to insert:
-        all target types x all command class x all compression module types x different priorities.
-        """
-        layout = TransformationLayout()
-        for idx, (target_type, (command_builder, command_type), priority) in enumerate(
-            itertools.product(
-                cls.AVAILABLE_TARGET_TYPES, zip(cls.get_command_builders(), cls.COMMAND_TYPES), cls.PRIORITIES
-            )
-        ):
-            if command_type is PTSharedFnInsertionCommand:
-                if skip_model_transformer_unsupported and target_type in [
-                    TargetType.PRE_LAYER_OPERATION,
-                    TargetType.POST_LAYER_OPERATION,
-                ]:
-                    continue
-                command = cls._create_command(
-                    command_builder, target_type, priority, dummy_op_state, op_unique_name=f"UNIQUE_NAME_{idx}"
-                )
-            else:
-                command = cls._create_command(command_builder, target_type, priority, dummy_op_state)
-
-            layout.register(command)
-        return layout
-
-    @staticmethod
-    def _create_command(command_builder, target_type, priority, dummy_op_state, op_unique_name=None):
-        group_name = "CUSTOM_HOOKS_GROUP_NAME"
-
-        if DummyOpWithState.__name__ not in COMPRESSION_MODULES.registry_dict:
-            registered_dummy_op_cls = COMPRESSION_MODULES.register()(DummyOpWithState)
-        else:
-            registered_dummy_op_cls = DummyOpWithState
-        dummy_op = registered_dummy_op_cls(dummy_op_state)
-        if op_unique_name is None:
-            command = command_builder(target_type, priority, fn=dummy_op, group=group_name)
-        else:
-            command = command_builder(
-                target_type, priority, fn=dummy_op, group=group_name, op_unique_name=op_unique_name
-            )
-
-        return command
-
 
 class LeNet(nn.Module):
     INPUT_SIZE = 1, 32, 32
@@ -372,38 +255,61 @@ def from_state(cls, state: str):
         return cls(state)
 
 
-def target_points_are_equal(tp_original: PTTargetPoint, tp_recovered: PTTargetPoint):
-    if tp_original != tp_recovered:
+def target_points_are_equal(tp_left: PTTargetPoint, tp_right: PTTargetPoint) -> bool:
+    """
+    Returns True if given target points are equal and False otherwise.
+
+    :param tp_left: The first target point.
+    :param tp_right: The second target point.
+    :return: True if given target points are equal and False otherwise.
+    """
+    if tp_left != tp_right:
         return False
-    if tp_original.target_type == TargetType.OPERATOR_PRE_HOOK:
-        return tp_original.input_port_id == tp_recovered.input_port_id
+    if tp_left.target_type == TargetType.OPERATOR_PRE_HOOK:
+        return tp_left.input_port_id == tp_right.input_port_id
     return True
 
 
-def are_commands_equal(
-    command, applied_command, check_priority: bool = True, check_hooks_group_name: bool = True, check_fn_ref=True
-):
-    if type(applied_command) is not type(command):
+def commands_are_equal(
+    command_left: Union[PTInsertionCommand, PTSharedFnInsertionCommand],
+    command_right: Union[PTInsertionCommand, PTSharedFnInsertionCommand],
+    check_priority: bool = True,
+    check_hooks_group_name: bool = True,
+    check_fn_ref=True,
+) -> bool:
+    """
+    Returns True if given commands are equal and False otherwise.
+
+    :param command_left: The first command.
+    :param command_right: The second command.
+    :param check_priority: Whether to check insertion priority or not.
+    :param check_hooks_group_name: Whether to check hooks group name or not.
+    :param check_fn_ref: Whether to check fn by reference or not.
+    :return: True if given commands are equal and False otherwise.
+    """
+    if type(command_right) is not type(command_left):
         return False
 
     # Check reference to functions are equal.
-    if check_fn_ref and applied_command.fn is not command.fn:
+    if check_fn_ref and command_right.fn is not command_left.fn:
         return False
-    if check_hooks_group_name and applied_command.hooks_group_name != command.hooks_group_name:
+    if check_hooks_group_name and command_right.hooks_group_name != command_left.hooks_group_name:
         return False
-    if check_priority and applied_command.priority != command.priority:
+    if check_priority and command_right.priority != command_left.priority:
        return False
 
-    if isinstance(applied_command, PTInsertionCommand):
-        if not target_points_are_equal(command.target_point, applied_command.target_point):
+    if isinstance(command_right, PTInsertionCommand):
+        if not target_points_are_equal(command_left.target_point, command_right.target_point):
             return False
-    elif isinstance(applied_command, PTSharedFnInsertionCommand):
-        if not all(target_points_are_equal(a, b) for a, b in zip(command.target_points, applied_command.target_points)):
+    elif isinstance(command_right, PTSharedFnInsertionCommand):
+        if not all(
+            target_points_are_equal(a, b) for a, b in zip(command_left.target_points, command_right.target_points)
+        ):
             return False
         if (
-            applied_command.target_points != command.target_points
-            or applied_command.op_name != command.op_name
-            or applied_command.compression_module_type != command.compression_module_type
+            command_right.target_points != command_left.target_points
+            or command_right.op_name != command_left.op_name
+            or command_right.compression_module_type != command_left.compression_module_type
         ):
             return False
     else:
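The renamed test helpers target_points_are_equal and commands_are_equal (formerly are_commands_equal) now carry docstrings and symmetric left/right parameter names. A hedged usage sketch (the target node name is a placeholder, and PTInsertionCommand's default priority and hooks group are assumed to match between the two commands):

import torch

from nncf.common.graph.transformations.commands import TargetType
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from tests.torch.helpers import commands_are_equal

point = PTTargetPoint(
    target_type=TargetType.OPERATOR_PRE_HOOK,
    target_node_name="Model/NNCFConv2d[conv]/conv2d_0",  # placeholder
    input_port_id=0,
)
fn = torch.nn.Identity()

# Same fn reference, same target point: reported as equal.
assert commands_are_equal(PTInsertionCommand(point=point, fn=fn), PTInsertionCommand(point=point, fn=fn))

# fn is compared by reference, so a fresh module makes the commands differ.
other = PTInsertionCommand(point=point, fn=torch.nn.Identity())
assert not commands_are_equal(PTInsertionCommand(point=point, fn=fn), other)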
(Diffs for the remaining 5 changed files are not shown.)