[Torch] Drop PTQuantizerInsertionCommand #2584

Merged
Changes from 5 commits
23 changes: 7 additions & 16 deletions nncf/quantization/algorithms/min_max/torch_backend.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple, Union

import torch

@@ -36,7 +36,9 @@
from nncf.quantization.range_estimator import RangeEstimatorParameters
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.graph import PTTargetPoint
from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand
from nncf.torch.graph.transformations.command_creation import create_quantizer_insertion_command
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.hardware.config import PTHWConfig
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
@@ -128,17 +130,6 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) -
target_type = PTMinMaxAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP[target_type]
return PTTargetPoint(target_type, target_node_name, input_port_id=port_id)

@staticmethod
def create_quantizer_insertion_command(
nncf_graph: NNCFGraph,
target_point: PTTargetPoint,
quantizer_config: QuantizerConfig,
parameters: FakeQuantizeParameters,
) -> PTQuantizerInsertionCommand:
return PTMinMaxAlgoBackend._create_quantizer_insertion_command(
nncf_graph, target_point, quantizer_config, parameters
)

@staticmethod
def create_convert_insertion_command(
target_point: PTTargetPoint,
@@ -290,20 +281,20 @@ def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantiz
quantizer.scale = torch.nn.Parameter(parameters.input_high.data - quantizer.eps)

@staticmethod
def _create_quantizer_insertion_command(
def create_quantizer_insertion_command(
nncf_graph: NNCFGraph,
target_point: PTTargetPoint,
quantizer_config: QuantizerConfig,
parameters: FakeQuantizeParameters,
) -> PTQuantizerInsertionCommand:
) -> Union[PTInsertionCommand, PTSharedFnInsertionCommand]:
_, scale_shape, _ = PTMinMaxAlgoBackend._get_input_scale_shape(
nncf_graph, target_point, quantizer_config.per_channel
)

quantizer = PTMinMaxAlgoBackend._create_quantizer(
quantizer_config, scale_shape, parameters, target_point.target_type
)
return PTQuantizerInsertionCommand(target_point, quantizer)
return create_quantizer_insertion_command(target_point, quantizer)

@staticmethod
def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[OperatorMetatype]:
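The net change in this backend: the former private helper _create_quantizer_insertion_command is promoted to the public create_quantizer_insertion_command, and its return type widens from the dropped PTQuantizerInsertionCommand to Union[PTInsertionCommand, PTSharedFnInsertionCommand]. A minimal sketch of a call site under these assumptions (graph, tp, qconfig, fq_params, and layout are illustrative placeholders, not objects from this PR):

# Hypothetical call site; all inputs are placeholders for real objects.
command = PTMinMaxAlgoBackend.create_quantizer_insertion_command(
    graph,      # NNCFGraph of the traced model
    tp,         # PTTargetPoint from PTMinMaxAlgoBackend.target_point(...)
    qconfig,    # QuantizerConfig selected by the MinMax algorithm
    fq_params,  # FakeQuantizeParameters estimated from collected statistics
)
# Depending on tp.type, `command` is either a PTInsertionCommand
# (weight quantizer) or a PTSharedFnInsertionCommand (activation quantizer).
layout.register(command)  # assumed TransformationLayout API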
8 changes: 3 additions & 5 deletions nncf/torch/external_hook.py
@@ -11,7 +11,7 @@

from typing import Any

from nncf.torch.dynamic_graph.context import TracingContext
from nncf.torch.dynamic_graph.context import get_current_context

EXTERNAL_OP_STORAGE_NAME = "external_op"

@@ -26,17 +26,15 @@ class ExternalOpCallHook:
the base module execution.
"""

def __init__(self, storage_name: str, context: TracingContext, storage_key: str):
def __init__(self, storage_name: str, storage_key: str):
"""
:param storage_name: Attribute name of the model's NNCFInterface.
:param context: Current tracing context.
:param storage_key: Key to retrieve the callable hook.
"""
self._storage_name = storage_name
self._compressed_context = context
self._storage_key = storage_key

def __call__(self, *args: Any, **kwargs) -> Any:
replica = self._compressed_context.base_module_thread_local_replica
replica = get_current_context().base_module_thread_local_replica
storage = getattr(replica.nncf, self._storage_name)
return storage[self._storage_key](*args, **kwargs)
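The hook now resolves the active tracing context lazily on every call instead of holding a reference captured at construction time, which decouples the hook object from any particular model instance. A minimal usage sketch under assumptions (the storage key "conv1|OUTPUT" is an invented example, not taken from this PR):

# Assumed usage; the storage key is an invented example.
hook = ExternalOpCallHook(EXTERNAL_OP_STORAGE_NAME, storage_key="conv1|OUTPUT")
# At execution time, __call__ looks up the current context and dispatches:
#   replica = get_current_context().base_module_thread_local_replica
#   return getattr(replica.nncf, "external_op")["conv1|OUTPUT"](*args, **kwargs)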
25 changes: 25 additions & 0 deletions nncf/torch/graph/transformations/command_creation.py
@@ -9,13 +9,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

from torch import Tensor

from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.quantization.structs import NonWeightQuantizerId
from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.commands import PTWeightUpdateCommand
from nncf.torch.quantization.layers import BaseQuantizer


def create_bias_correction_command(node: NNCFNode, bias_value: Tensor) -> PTBiasCorrectionCommand:
@@ -40,3 +48,20 @@ def create_command_to_update_weight(node: NNCFNode, weight_value: Tensor) -> PTW
"""
target_point = PTTargetPoint(TargetType.LAYER, node.node_name)
return PTWeightUpdateCommand(target_point, weight_value)


def create_quantizer_insertion_command(
target_point: PTTargetPoint, quantizer: BaseQuantizer
) -> Union[PTInsertionCommand, PTSharedFnInsertionCommand]:
if target_point.type is TargetType.OPERATION_WITH_WEIGHTS:
return PTInsertionCommand(target_point, quantizer, TransformationPriority.QUANTIZATION_PRIORITY)

quantizer_id = NonWeightQuantizerId(target_point.target_node_name, target_point.input_port_id)
storage_key = str(quantizer_id)
return PTSharedFnInsertionCommand(
target_points=[target_point],
fn=quantizer,
op_unique_name=storage_key,
compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
priority=TransformationPriority.QUANTIZATION_PRIORITY,
)
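This factory centralizes the dispatch rule that previously lived inside the model transformer: a quantizer on a weight target is inserted directly with quantization priority, while an activation quantizer becomes a shared external module keyed by the string form of its NonWeightQuantizerId. A hedged sketch of both branches (node names and the quantizer instance are assumed placeholders for a configured BaseQuantizer):

# Illustrative only; `quantizer` stands for a configured BaseQuantizer.
weight_tp = PTTargetPoint(TargetType.OPERATION_WITH_WEIGHTS, "conv1")
command = create_quantizer_insertion_command(weight_tp, quantizer)
assert isinstance(command, PTInsertionCommand)

activation_tp = PTTargetPoint(TargetType.OPERATOR_POST_HOOK, "relu0")
command = create_quantizer_insertion_command(activation_tp, quantizer)
assert isinstance(command, PTSharedFnInsertionCommand)
# The shared module is stored under str(NonWeightQuantizerId("relu0", None)).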
8 changes: 8 additions & 0 deletions nncf/torch/graph/transformations/commands.py
@@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import Any, Callable, Dict, List

import torch
@@ -150,19 +151,26 @@ def requires_graph_rebuild(self):
return self.priority == TransformationPriority.QUANTIZATION_PRIORITY


class ExtraCompressionModuleType(Enum):
EXTERNAL_QUANTIZER = 0
EXTERNAL_OP = 1


class PTSharedFnInsertionCommand(PTTransformationCommand):
def __init__(
self,
target_points: List[PTTargetPoint],
fn: Callable,
op_unique_name: str,
compression_module_type: ExtraCompressionModuleType = ExtraCompressionModuleType.EXTERNAL_OP,
priority: TransformationPriority = TransformationPriority.DEFAULT_PRIORITY,
hooks_group_name: str = DEFAULT_HOOKS_GROUP_NAME,
):
super().__init__(TransformationType.INSERT, None)
self.target_points = target_points
self.fn = fn
self.op_name = op_unique_name
self.compression_module_type = compression_module_type
self.priority = priority
self.hooks_group_name = hooks_group_name

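With the new enum, a single PTSharedFnInsertionCommand now records which shared-module storage it targets; EXTERNAL_OP stays the default, so existing callers are unaffected. A minimal construction sketch (tp and quantizer are assumed placeholders, and the op_unique_name is an invented key):

# Assumed example; `tp` and `quantizer` are placeholders.
command = PTSharedFnInsertionCommand(
    target_points=[tp],
    fn=quantizer,
    op_unique_name="conv1|OUTPUT",  # key inside the shared-module storage
    compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
    priority=TransformationPriority.QUANTIZATION_PRIORITY,
)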
126 changes: 69 additions & 57 deletions nncf/torch/model_transformer.py
@@ -11,32 +11,31 @@

import copy
from collections import defaultdict
from typing import Callable, Dict, List, Tuple
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor
from torch import nn
from torch.nn.parameter import Parameter

from nncf.common.graph.model_transformer import ModelTransformer
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.quantization.structs import NonWeightQuantizerId
from nncf.torch.external_hook import EXTERNAL_OP_STORAGE_NAME
from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTModelExtractionWithFusedBiasCommand
from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.commands import PTWeightUpdateCommand
from nncf.torch.graph.transformations.layout import PTTransformationLayout
from nncf.torch.model_analyzer import get_potential_fused_node
from nncf.torch.module_operations import UpdateWeight
from nncf.torch.nncf_network import ExtraCompressionModuleType
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.nncf_network import PTInsertionPoint
from nncf.torch.nncf_network import compression_module_type_to_attr_name
from nncf.torch.quantization.external_quantizer import ExternalOpCallHook
from nncf.torch.quantization.external_quantizer import ExternalQuantizerCallHook
from nncf.torch.utils import get_model_device
from nncf.torch.utils import is_multidevice

@@ -49,12 +48,15 @@ class PTModelTransformer(ModelTransformer):
def __init__(self, model: NNCFNetwork):
super().__init__(model)

device = None
if not is_multidevice(model):
device = get_model_device(model)

self._command_transformation_ordered_pairs = [
(PTModelExtractionWithFusedBiasCommand, self._apply_extraction_with_fused_bias_transformations),
(PTInsertionCommand, self._apply_insertion_transformations),
(PTQuantizerInsertionCommand, self._apply_quantizer_insertion_transformations),
(PTInsertionCommand, partial(self._apply_insertion_transformations, device=device)),
(PTSharedFnInsertionCommand, partial(self._apply_shared_nodes_insertion, device=device)),
(PTBiasCorrectionCommand, self._apply_bias_correction_transformations),
(PTSharedFnInsertionCommand, self._apply_shared_nodes_insertion),
(PTWeightUpdateCommand, self._apply_weights_update_transformations),
]

@@ -78,12 +80,16 @@ def transform(self, transformation_layout: PTTransformationLayout) -> NNCFNetwor
return model

@staticmethod
def _apply_insertion_transformations(model: NNCFNetwork, transformations: List[PTInsertionCommand]) -> NNCFNetwork:
def _apply_insertion_transformations(
model: NNCFNetwork, transformations: List[PTInsertionCommand], device: Optional[torch.device]
) -> NNCFNetwork:
"""
Applies insertion transformations to the model.

:param model: Model to apply transformations.
:param transformations: List of the insertion transformations.
:param device: Target device for the insertion functions. Applies only to
functions which are subclasses of torch.nn.Module. Does nothing if device is None.
:return: A modified NNCFNetwork.
"""
node_to_op_address_mapping = model.nncf.get_node_to_op_address_mapping()
@@ -98,7 +104,11 @@ def _apply_insertion_transformations(model: NNCFNetwork, transformations: List[P
input_port_id=target_point.input_port_id,
replaced_modules=model.nncf.replace_modules,
)

fn = transformation_command.fn
if device is not None and isinstance(fn, torch.nn.Module):
fn.to(device)

if model.nncf.replace_modules and target_point.type is TargetType.OPERATION_WITH_WEIGHTS:
fn = UpdateWeight(fn)
tup = (fn, transformation_command)
@@ -113,21 +123,63 @@ def _apply_insertion_transformations(model: NNCFNetwork, transformations: List[P

@staticmethod
def _apply_shared_nodes_insertion(
model: NNCFNetwork, transformations: List[PTSharedFnInsertionCommand]
model: NNCFNetwork,
transformations: List[PTSharedFnInsertionCommand],
device: Optional[torch.device],
) -> NNCFNetwork:
compression_model_type = ExtraCompressionModuleType.EXTERNAL_OP
"""
Applies PTSharedFnInsertionCommand commands. For each command, the method inserts
a torch module into the NNCFNetwork and registers a call hook for each of the command's target points.

:param model: Model to apply transformations.
:param transformations: List of the shared-fn insertion transformations.
:param device: Target device for the insertion functions. Applies only to
functions which are subclasses of torch.nn.Module. Does nothing if device is None.
:return: A modified NNCFNetwork.
"""
compression_type_vs_transformations = defaultdict(list)
for transformation in transformations:
compression_type_vs_transformations[transformation.compression_module_type].append(transformation)

for compression_module_type, transformations in compression_type_vs_transformations.items():
model = PTModelTransformer._apply_shared_node_insertion_with_compression_type(
model, transformations, device, compression_module_type
)
return model

@staticmethod
def _apply_shared_node_insertion_with_compression_type(
model: NNCFNetwork,
transformations: List[PTSharedFnInsertionCommand],
device: Optional[torch.device],
compression_module_type: ExtraCompressionModuleType,
):
"""
Performs _apply_shared_nodes_insertion with the specified compression module type, which is
used for each transformation command.

if not model.nncf.is_compression_module_registered(compression_model_type):
model.nncf.register_compression_module_type(compression_model_type)
:param model: Model to apply transformations.
:param transformations: List of the shared-fn insertion transformations.
:param device: Target device for the insertion functions. Applies only to
functions which are subclasses of torch.nn.Module. Does nothing if device is None.
:param compression_module_type: Common compression module type for all commands.
:return: A modified NNCFNetwork.
"""
if not model.nncf.is_compression_module_registered(compression_module_type):
model.nncf.register_compression_module_type(compression_module_type)

insertion_commands: List[PTInsertionCommand] = []

for shared_command in transformations:
model.nncf.add_compression_module(shared_command.op_name, shared_command.fn, compression_model_type)
fn = shared_command.fn
if device is not None:
fn.to(device)

model.nncf.add_compression_module(shared_command.op_name, fn, compression_module_type)

for target_point in shared_command.target_points:
fn = ExternalOpCallHook(
EXTERNAL_OP_STORAGE_NAME, model.nncf.get_tracing_context(), shared_command.op_name
compression_module_type_to_attr_name(compression_module_type), shared_command.op_name
)
insertion_commands.append(
PTInsertionCommand(
@@ -138,47 +190,7 @@ def _apply_shared_nodes_insertion(
)
)

return PTModelTransformer._apply_insertion_transformations(model, insertion_commands)

@staticmethod
def _apply_quantizer_insertion_transformations(
model: NNCFNetwork, transformations: List[PTQuantizerInsertionCommand]
) -> NNCFNetwork:
"""
Applies quantizer insertion transformations on the model.

:param model: Model to apply transformations.
:param transformations: List of the OVQuantizerInsertionCommand transformations.
:return: Model with inserted FakeQuantize nodes.
"""
compression_model_type = ExtraCompressionModuleType.EXTERNAL_QUANTIZER

if not model.nncf.is_compression_module_registered(compression_model_type):
model.nncf.register_compression_module_type(compression_model_type)

insertion_commands: List[PTInsertionCommand] = []
device = None
if not is_multidevice(model):
device = get_model_device(model)

for transformation_command in transformations:
target_point: PTTargetPoint = transformation_command.target_point
quantizer_module = transformation_command.quantizer
if device is not None:
quantizer_module = quantizer_module.to(device)
fn = quantizer_module

if target_point.type is not TargetType.OPERATION_WITH_WEIGHTS:
quantizer_id = NonWeightQuantizerId(target_point.target_node_name, target_point.input_port_id)
storage_key = str(quantizer_id)
model.nncf.add_compression_module(storage_key, quantizer_module, compression_model_type)
fn = ExternalQuantizerCallHook(model.nncf.get_tracing_context(), storage_key)

insertion_commands.append(
PTInsertionCommand(target_point, fn, TransformationPriority.QUANTIZATION_PRIORITY)
)

return PTModelTransformer._apply_insertion_transformations(model, insertion_commands)
return PTModelTransformer._apply_insertion_transformations(model, insertion_commands, device)

@staticmethod
def _apply_extraction_with_fused_bias_transformations(
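Summing up the transformer changes: the target device is now resolved once in __init__ and bound into the insertion handlers with functools.partial, shared-fn commands are grouped by compression module type before insertion, and the dedicated quantizer-insertion pass disappears because create_quantizer_insertion_command emits plain insertion or shared-fn commands. A condensed sketch of the device-binding pattern (an approximation of the new wiring with `model` as an NNCFNetwork placeholder, not the exact class body):

# Condensed approximation of the new __init__ wiring; `model` is a placeholder.
from functools import partial

device = None
if not is_multidevice(model):
    device = get_model_device(model)  # only resolved for single-device models

command_transformation_pairs = [
    (PTInsertionCommand, partial(PTModelTransformer._apply_insertion_transformations, device=device)),
    (PTSharedFnInsertionCommand, partial(PTModelTransformer._apply_shared_nodes_insertion, device=device)),
]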