Skip to content

Commit

Permalink
[Torch] Drop PTQuantizerInsertionCommand (#2584)
Browse files Browse the repository at this point in the history
Preparation for #2531

### Changes

1) `PTQuantizerInsertionCommand` is removed and replaced with the
`create_quantizer_insertion_command` function
2) `PTSharedFnInsertionCommand` is extended with one new attribute:
`compression_module_type`
3) `ExternalOpCallHook` no longer requires a context in its constructor
4) Multidevice support is moved from
`apply_quantizers_insertion_commands_transformation` to
`apply_insertion_transformation`

### Reason for changes

1) To make it easier to store and restore commands: fewer commands mean
fewer adapters are needed
2) To make it possible to express `PTQuantizerInsertionCommand` via
`PTSharedFnInsertionCommand`
3) To make it possible to create `ExternalOpCallHook` outside of the
`PTModelTransformer`
4) To unify multidevice support for all insertion operations

### Related tickets
2531

### Tests

1) `test_quantizer_insertion_transformation` is updated
2) -
3) `test_shared_fn_insertion_point` is updated
4) `test_pt_insertion_command` is introduced
  • Loading branch information
daniil-lyakhov authored Apr 2, 2024
1 parent fa8b702 commit f7a5660
Show file tree
Hide file tree
Showing 21 changed files with 403 additions and 217 deletions.
23 changes: 7 additions & 16 deletions nncf/quantization/algorithms/min_max/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple, Union

import torch

Expand All @@ -36,7 +36,9 @@
from nncf.quantization.range_estimator import RangeEstimatorParameters
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.graph import PTTargetPoint
from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand
from nncf.torch.graph.transformations.command_creation import create_quantizer_insertion_command
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.hardware.config import PTHWConfig
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
Expand Down Expand Up @@ -128,17 +130,6 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) -
target_type = PTMinMaxAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP[target_type]
return PTTargetPoint(target_type, target_node_name, input_port_id=port_id)

@staticmethod
def create_quantizer_insertion_command(
nncf_graph: NNCFGraph,
target_point: PTTargetPoint,
quantizer_config: QuantizerConfig,
parameters: FakeQuantizeParameters,
) -> PTQuantizerInsertionCommand:
return PTMinMaxAlgoBackend._create_quantizer_insertion_command(
nncf_graph, target_point, quantizer_config, parameters
)

@staticmethod
def create_convert_insertion_command(
target_point: PTTargetPoint,
Expand Down Expand Up @@ -290,20 +281,20 @@ def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantiz
quantizer.scale = torch.nn.Parameter(parameters.input_high.data - quantizer.eps)

@staticmethod
def _create_quantizer_insertion_command(
def create_quantizer_insertion_command(
nncf_graph: NNCFGraph,
target_point: PTTargetPoint,
quantizer_config: QuantizerConfig,
parameters: FakeQuantizeParameters,
) -> PTQuantizerInsertionCommand:
) -> Union[PTInsertionCommand, PTSharedFnInsertionCommand]:
_, scale_shape, _ = PTMinMaxAlgoBackend._get_input_scale_shape(
nncf_graph, target_point, quantizer_config.per_channel
)

quantizer = PTMinMaxAlgoBackend._create_quantizer(
quantizer_config, scale_shape, parameters, target_point.target_type
)
return PTQuantizerInsertionCommand(target_point, quantizer)
return create_quantizer_insertion_command(target_point, quantizer)

@staticmethod
def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[OperatorMetatype]:
Expand Down
8 changes: 3 additions & 5 deletions nncf/torch/external_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from typing import Any

from nncf.torch.dynamic_graph.context import TracingContext
from nncf.torch.dynamic_graph.context import get_current_context

EXTERNAL_OP_STORAGE_NAME = "external_op"

Expand All @@ -26,17 +26,15 @@ class ExternalOpCallHook:
the base module execution.
"""

def __init__(self, storage_name: str, context: TracingContext, storage_key: str):
def __init__(self, storage_name: str, storage_key: str):
"""
:param storage_name: Attribute name of a model NNCFInterface.
:param context: Current tracing context.
:param storage_key: Key to retrieve callable hook
"""
self._storage_name = storage_name
self._compressed_context = context
self._storage_key = storage_key

def __call__(self, *args: Any, **kwargs) -> Any:
replica = self._compressed_context.base_module_thread_local_replica
replica = get_current_context().base_module_thread_local_replica
storage = getattr(replica.nncf, self._storage_name)
return storage[self._storage_key](*args, **kwargs)
25 changes: 25 additions & 0 deletions nncf/torch/graph/transformations/command_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

from torch import Tensor

from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.quantization.structs import NonWeightQuantizerId
from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.commands import PTWeightUpdateCommand
from nncf.torch.quantization.layers import BaseQuantizer


def create_bias_correction_command(node: NNCFNode, bias_value: Tensor) -> PTBiasCorrectionCommand:
Expand All @@ -40,3 +48,20 @@ def create_command_to_update_weight(node: NNCFNode, weight_value: Tensor) -> PTW
"""
target_point = PTTargetPoint(TargetType.LAYER, node.node_name)
return PTWeightUpdateCommand(target_point, weight_value)


def create_quantizer_insertion_command(
    target_point: PTTargetPoint, quantizer: BaseQuantizer
) -> Union[PTInsertionCommand, PTSharedFnInsertionCommand]:
    """
    Builds the command that inserts the given quantizer at the given target point.

    :param target_point: Target point the quantizer should be inserted at.
    :param quantizer: Quantizer to insert.
    :return: A `PTInsertionCommand` when the target point type is
        `TargetType.OPERATION_WITH_WEIGHTS`; otherwise a `PTSharedFnInsertionCommand`
        holding this single target point and marked with the
        `ExtraCompressionModuleType.EXTERNAL_QUANTIZER` compression module type.
        Either command is created with `TransformationPriority.QUANTIZATION_PRIORITY`.
    """
    quantization_priority = TransformationPriority.QUANTIZATION_PRIORITY

    if target_point.type is TargetType.OPERATION_WITH_WEIGHTS:
        command = PTInsertionCommand(target_point, quantizer, quantization_priority)
    else:
        # The string form of the non-weight quantizer id doubles as the unique op name.
        non_weight_id = NonWeightQuantizerId(target_point.target_node_name, target_point.input_port_id)
        command = PTSharedFnInsertionCommand(
            target_points=[target_point],
            fn=quantizer,
            op_unique_name=str(non_weight_id),
            compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
            priority=quantization_priority,
        )
    return command
27 changes: 8 additions & 19 deletions nncf/torch/graph/transformations/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import Any, Callable, Dict, List

import torch
Expand Down Expand Up @@ -150,45 +151,33 @@ def requires_graph_rebuild(self):
return self.priority == TransformationPriority.QUANTIZATION_PRIORITY


class ExtraCompressionModuleType(Enum):
    """
    Kinds of extra compression modules a shared-function insertion can be tagged with.

    `PTSharedFnInsertionCommand` accepts a member of this enum as its
    `compression_module_type` argument.
    """

    # Tag for inserted quantizers.
    EXTERNAL_QUANTIZER = 0
    # Tag for any other inserted op; the default of `PTSharedFnInsertionCommand`.
    EXTERNAL_OP = 1


class PTSharedFnInsertionCommand(PTTransformationCommand):
    """
    Insertion command that attaches one callable to several target points at once.
    """

    def __init__(
        self,
        target_points: List[PTTargetPoint],
        fn: Callable,
        op_unique_name: str,
        compression_module_type: ExtraCompressionModuleType = ExtraCompressionModuleType.EXTERNAL_OP,
        priority: TransformationPriority = TransformationPriority.DEFAULT_PRIORITY,
        hooks_group_name: str = DEFAULT_HOOKS_GROUP_NAME,
    ):
        """
        :param target_points: Target points the callable is associated with.
        :param fn: Callable to insert.
        :param op_unique_name: Unique name of the inserted op; stored as `op_name`.
        :param compression_module_type: Type tag of the inserted module,
            `ExtraCompressionModuleType.EXTERNAL_OP` by default.
        :param priority: Priority of the transformation.
        :param hooks_group_name: Name of the hooks group; stored on the command as is.
        """
        # The base command receives no single target point: the points of
        # this command are kept in `target_points` instead.
        super().__init__(TransformationType.INSERT, None)

        self.target_points = target_points
        self.fn = fn
        self.op_name = op_unique_name
        self.compression_module_type = compression_module_type
        self.priority = priority
        self.hooks_group_name = hooks_group_name

    def requires_graph_rebuild(self) -> bool:
        """
        :return: Always True for this command.
        """
        return True


class PTQuantizerInsertionCommand(PTTransformationCommand):
"""
Insertion quantizer operation to the models.
"""

def __init__(
self,
point: PTTargetPoint,
quantizer: "BaseQuantizer", # noqa: F821
hooks_group_name: str = DEFAULT_HOOKS_GROUP_NAME,
):
super().__init__(TransformationType.INSERT, point)
self.quantizer = quantizer
self.hooks_group_name = hooks_group_name

def requires_graph_rebuild(self):
return True


class PTModelExtractionWithFusedBiasCommand(PTCommand):
"""
Extracts sequence by name with node that contain fused bias.
Expand Down
9 changes: 6 additions & 3 deletions nncf/torch/model_graph_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.operator_metatypes import CONST_NOOP_METATYPES
from nncf.torch.dynamic_graph.context import PreHookId
from nncf.torch.external_hook import ExternalOpCallHook
from nncf.torch.graph import operator_metatypes as om
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.external_quantizer import ExternalQuantizerCallHook
from nncf.torch.quantization.layers import AsymmetricQuantizer
from nncf.torch.quantization.layers import BaseQuantizer
from nncf.torch.quantization.layers import SymmetricQuantizer

CONV_META_TYPES = [
Expand Down Expand Up @@ -295,7 +296,9 @@ def get_fake_quantizer(
hook_container = model.nncf._compressed_context._post_hooks.get(op_addr, {})

for call_hook in hook_container.values():
if isinstance(call_hook, ExternalQuantizerCallHook):
if isinstance(call_hook, ExternalOpCallHook):
storage = getattr(model.nncf, call_hook._storage_name)
return storage[call_hook._storage_key]
module = storage[call_hook._storage_key]
if isinstance(module, BaseQuantizer):
return module
return None
Loading

0 comments on commit f7a5660

Please sign in to comment.