Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch] NNCFNetwork.transformation_layout #2595

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions nncf/torch/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,35 @@ def get_input_shape_for_insertion_point(self, insertion_point: PTTargetPoint) ->
return quantizer_input_shape

def get_op_nodes_in_scope(self, scope: Scope) -> List[NNCFNode]:
"""
Returns all NNCFNodes inside the given scope.

:param scope: Given scope.
:return: All NNCFNodes inside the given scope.
"""
matching_graph_op_nodes = []
for scope_str, nodes_in_module in self._layer_name_vs_shared_nodes.items():
module_scope = Scope.from_str(scope_str)
if module_scope in scope:
matching_graph_op_nodes.extend(nodes_in_module)
return matching_graph_op_nodes

def get_op_nodes_with_scope(self, scope: Scope) -> List[NNCFNode]:
"""
Returns all NNCFNodes which share the given scope.

:param scope: Given scope.
:return: All NNCFNodes which share the given scope.
"""
return self._layer_name_vs_shared_nodes[str(scope)]
AlexanderDokuchaev marked this conversation as resolved.
Show resolved Hide resolved

def get_scope_by_node_name(self, node_name: NNCFNodeName) -> Scope:
"""
Returns a scope which corresponds to the given NNCF node name.

:param node_name: Given node name.
:return: A scope which corresponds to the given NNCF node name.
"""
matches = []
for node_id, scope_str in self._node_ids_vs_layer_names.items():
node = self.get_node_by_id(node_id)
Expand Down
26 changes: 25 additions & 1 deletion nncf/torch/graph/transformations/command_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union
from typing import List, Optional, Union

import torch
from torch import Tensor

from nncf.common.graph.graph import NNCFNode
Expand Down Expand Up @@ -82,3 +83,26 @@ def create_shared_quantizer_insertion_command(
compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
priority=TransformationPriority.QUANTIZATION_PRIORITY,
)


def create_pt_insertion_command(
module: torch.nn.Module,
target_type: TargetType,
target_node_name: str,
priority: int,
input_port_id: Optional[int],
) -> PTInsertionCommand:
"""
Creates a PTInsertionCommand.

:param module: Torch module to insert.
:param target_type: Insertion command target type.
:param target_name: Insertion command target name.
:param priority: Insertion command priority.
:param input_port_id: Insertion command input port id.
:return: A PTInsertionCommand
"""
target_point = PTTargetPoint(
target_type=target_type, target_node_name=target_node_name, input_port_id=input_port_id
)
return PTInsertionCommand(point=target_point, fn=module, priority=priority)
6 changes: 3 additions & 3 deletions nncf/torch/graph/transformations/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

from enum import Enum
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, Union

import torch

Expand Down Expand Up @@ -139,7 +139,7 @@ def __init__(
self,
point: PTTargetPoint,
fn: Callable,
priority: TransformationPriority = TransformationPriority.DEFAULT_PRIORITY,
priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY,
hooks_group_name: str = DEFAULT_HOOKS_GROUP_NAME,
):
super().__init__(TransformationType.INSERT, point)
Expand All @@ -164,7 +164,7 @@ def __init__(
fn: Callable,
op_unique_name: str,
compression_module_type: ExtraCompressionModuleType = ExtraCompressionModuleType.EXTERNAL_OP,
priority: TransformationPriority = TransformationPriority.DEFAULT_PRIORITY,
priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY,
hooks_group_name: str = DEFAULT_HOOKS_GROUP_NAME,
):
super().__init__(TransformationType.INSERT, None)
Expand Down
138 changes: 138 additions & 0 deletions nncf/torch/nncf_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from nncf.common.utils.debug import is_debug
from nncf.torch.debug import CombinedDebugInterface
from nncf.torch.debug import debuggable_forward
from nncf.torch.dynamic_graph.context import PreHookId
from nncf.torch.dynamic_graph.context import TracingContext
from nncf.torch.dynamic_graph.graph import DynamicGraph
from nncf.torch.dynamic_graph.graph import ShapeIgnoringTensorMetaComparator
Expand All @@ -60,16 +61,21 @@
from nncf.torch.dynamic_graph.wrappers import wrap_module_call
from nncf.torch.dynamic_graph.wrappers import wrap_parameters
from nncf.torch.external_hook import EXTERNAL_OP_STORAGE_NAME
from nncf.torch.external_hook import ExternalOpCallHook
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.graph_builder import GraphBuilder
from nncf.torch.graph.graph_builder import GraphConverter
from nncf.torch.graph.operator_metatypes import OPERATORS_WITH_WEIGHTS_METATYPES
from nncf.torch.graph.operator_metatypes import PTSplitMetatype
from nncf.torch.graph.transformations.command_creation import create_pt_insertion_command
from nncf.torch.graph.transformations.commands import DEFAULT_HOOKS_GROUP_NAME
from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.layout import PTTransformationLayout
from nncf.torch.knowledge_distillation.knowledge_distillation_handler import KnowledgeDistillationLossHandler
from nncf.torch.layer_utils import _NNCFModuleMixin
from nncf.torch.module_operations import UpdateWeight
from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules
from nncf.torch.quantization.external_quantizer import EXTERNAL_QUANTIZERS_STORAGE_NAME
from nncf.torch.utils import compute_FLOPs_hook
Expand Down Expand Up @@ -778,6 +784,127 @@ def _collect_eval_op_scopes(self, model: nn.Module, dummy_forward_fn: Callable)
result.append(scope_in_model)
return result

def transformation_layout(self) -> PTTransformationLayout:
"""
Collects all hooks applied to the NNCFNetwork, converts them to insertion commands
and returns in PTTransformationLayout format. Default hooks group name is used in
recovered commands, so hooks group names specified during the model modification
become outdated.

:return: Transformation layout with all commands applied to the NNCFNetwork.
"""

def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str):
AlexanderDokuchaev marked this conversation as resolved.
Show resolved Hide resolved
"""
Check given external op call hook reference is correct.

:param hook: External op call hook to check correctness.
:param info: Info to log in case op call hook references are broken.
"""
assert hasattr(
self, hook._storage_name
), f"Storage name {hook._storage_name} is not registered. Info: {info}"
assert hook._storage_key in getattr(
self, hook._storage_name
), f"Key {hook._storage_key} is not registered in {hook._storage_name}. Info: {info}"

context_hooks = defaultdict(lambda: defaultdict(list))
transformation_layout = PTTransformationLayout()
nncf_graph = self.get_graph()
nncf_node_names_map = self.get_op_address_to_op_name_map()

# Collect pre/post layer and op with weights insertion commands
for nncf_module, module_scope in self.get_nncf_modules().items():
for ops, target_type in (
(nncf_module.pre_ops, TargetType.PRE_LAYER_OPERATION),
(nncf_module.post_ops, TargetType.POST_LAYER_OPERATION),
):
for priority, module in enumerate(ops.values()):
nodes_in_scope = nncf_graph.get_op_nodes_with_scope(module_scope)
# Several NNCFNodes means that current NNCFModule was called
# several times. Only one insertion command is required to
# call hook as much times as the current NNCFModule, therefore
# we use first correspondent NNCFNode.
nncf_node = nodes_in_scope[0]
command_target_type = target_type
if isinstance(module, UpdateWeight):
command_target_type = TargetType.OPERATION_WITH_WEIGHTS
module = module.op
if not isinstance(module, ExternalOpCallHook):
command = create_pt_insertion_command(
module, command_target_type, nncf_node.node_name, priority, None
)
transformation_layout.register(command)
continue

info = (
f"TargetType: {command_target_type}, nncf node name: {nncf_node.node_name},"
f" priority: {priority}, fn: {module}"
)
_check_external_call_hook_is_valid(module, info)

context_hooks[module._storage_name][module._storage_key].append(
(command_target_type, nncf_node.node_name, priority, module, None)
)

# Collect all pre/post hooks commands
for ops, target_type in (
(self._compressed_context._pre_hooks, TargetType.OPERATOR_PRE_HOOK),
(self._compressed_context._post_hooks, TargetType.OPERATOR_POST_HOOK),
):
for op_address, hooks in ops.items():
if isinstance(op_address, PreHookId):
input_port_id = op_address.input_port_id
op_address = op_address.op_address
else:
input_port_id = None
for priority, fn in enumerate(hooks.values()):
target_node_names = nncf_node_names_map[op_address]
# Operation address is unique for each module call
assert len(target_node_names) == 1
target_node_name = target_node_names[0]

if not isinstance(fn, ExternalOpCallHook):
command = create_pt_insertion_command(
fn, target_type, target_node_name, priority, input_port_id
)
transformation_layout.register(command)
continue

info = f"TargetType: {target_type}, op_address: {op_address}, priority: {priority}, fn: {fn}"
_check_external_call_hook_is_valid(fn, info)

context_hooks[fn._storage_name][fn._storage_key].append(
(target_type, target_node_name, priority, fn, input_port_id)
)

# Create shared fn insertion commands according to external hooks collected from
# pre/post layer, pre/post hooks and op with weights target points.
for module_type_name, storage in context_hooks.items():
for storage_key, call_hook_list_info in storage.items():
compression_module = getattr(self, module_type_name)[storage_key]
target_points = []
for target_type, target_node_name, priority, fn, input_port_id in call_hook_list_info:
target_points.append(PTTargetPoint(target_type, target_node_name, input_port_id=input_port_id))

if module_type_name == EXTERNAL_QUANTIZERS_STORAGE_NAME:
module_type = ExtraCompressionModuleType.EXTERNAL_QUANTIZER
elif module_type_name == EXTERNAL_OP_STORAGE_NAME:
module_type = ExtraCompressionModuleType.EXTERNAL_OP
else:
raise RuntimeError(f"Module type {module_type_name} is not supported")

command = PTSharedFnInsertionCommand(
target_points=target_points,
fn=compression_module,
op_unique_name=storage_key,
compression_module_type=module_type,
priority=priority,
)
transformation_layout.register(command)

return transformation_layout

def get_node_to_op_address_mapping(self) -> Dict[NNCFNodeName, OperationAddress]:
"""
Returns map of NNCFGraph node names vs DynamicGraph operation addresses.
Expand All @@ -796,6 +923,17 @@ def get_node_to_op_address_mapping(self) -> Dict[NNCFNodeName, OperationAddress]
retval[nncf_node.node_name] = op_address
return retval

def get_op_address_to_op_name_map(self) -> Dict[OperationAddress, NNCFNodeName]:
"""
Returns map of DynamicGraph operation addresses vs NNCFGraph node names.

:return: DynamicGraph operation addresses vs NNCFGraph node names.
"""
retval = defaultdict(list)
for nncf_node_name, op_address in self.get_node_to_op_address_mapping().items():
retval[op_address].append(nncf_node_name)
return retval

def set_compression_controller(self, ctrl: CompressionAlgorithmController):
self.compression_controller = ctrl

Expand Down
Loading
Loading