Skip to content

Commit

Permalink
Shared Op insertion is implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Nov 8, 2023
1 parent 944b5cd commit 84b0a53
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 25 deletions.
1 change: 0 additions & 1 deletion nncf/quantization/algorithms/smooth_quant/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from nncf.common.factory import ModelTransformerFactory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging import nncf_logger
from nncf.common.logging.track_progress import track
Expand Down
4 changes: 1 addition & 3 deletions nncf/quantization/algorithms/smooth_quant/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@

from abc import ABC
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple, TypeVar
from typing import List, Optional, Tuple, TypeVar

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.transformations.commands import TargetPoint
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector

TModel = TypeVar("TModel")
Expand Down
6 changes: 2 additions & 4 deletions nncf/quantization/algorithms/smooth_quant/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple

import numpy as np
import openvino.runtime as ov
import torch

from nncf.common.graph import NNCFGraph
Expand All @@ -22,7 +21,6 @@
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
from nncf.openvino.graph.node_utils import get_channel_agnostic_reduction_axes
from nncf.openvino.graph.transformations.commands import OVMultiplyInsertionCommand
from nncf.openvino.graph.transformations.commands import OVWeightUpdateCommand
Expand Down Expand Up @@ -133,7 +131,7 @@ def weight_update_command(node_with_weight: NNCFNode, weight_value: np.ndarray)
def scale_insertion_command(
    source_node: NNCFNode, scale_value: np.ndarray, port_id: int, nodes: List[NNCFNode], scale_node_name: str
) -> OVMultiplyInsertionCommand:
    """
    Build the command that inserts the smoothing multiply in front of the given nodes.

    :param source_node: Not used by this implementation.
    :param scale_value: Value the activations are multiplied by.
    :param port_id: Input port of every node in `nodes` at which the multiply is inserted.
    :param nodes: Nodes that receive the (shared) multiply on their input.
    :param scale_node_name: Name given to the inserted multiply operation.
    :return: Command produced by `multiply_insertion_command`.
    """
    # NOTE(review): the return annotation still names the OpenVINO command type, while
    # multiply_insertion_command is presumably the torch helper that builds a
    # PTSharedFnInsertionCommand - confirm and update together with this file's imports.
    # The superseded call `multiply_insertion_command(source_node, scale_value)` is dropped:
    # it returned early and left the call below unreachable.
    return multiply_insertion_command(nodes, scale_value, scale_node_name, port_id)

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape: Tuple[int, ...]) -> int:
Expand Down
20 changes: 14 additions & 6 deletions nncf/torch/graph/transformations/command_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.commands import PTWeightUpdateCommand

Expand Down Expand Up @@ -48,12 +49,19 @@ def create_command_to_update_weight(node: NNCFNode, weight_value: Tensor) -> PTW


def multiply_insertion_command(
    target_nodes: List[NNCFNode], scale_value: Tensor, scale_node_name: str, input_port_id: int
) -> PTSharedFnInsertionCommand:
    """
    Create a command that inserts one shared multiply-by-scale module before several nodes.

    :param target_nodes: Nodes whose input is multiplied by `scale_value`.
    :param scale_value: Multiplier applied to the input tensor via `torch.mul`.
    :param scale_node_name: Unique name of the inserted operation.
    :param input_port_id: Input port of each target node at which the pre-hook is placed.
    :return: Shared insertion command holding one pre-hook insertion per target node
        and a single `SQMultiply` module instance.
    """

    class SQMultiply(torch.nn.Module):
        """Multiply the incoming tensor by a fixed scale."""

        def __init__(self, scale_value):
            super().__init__()
            # NOTE(review): if scale_value is a plain tensor, it is kept as an ordinary attribute,
            # so it is not registered as a buffer/parameter of the module - confirm this is intended.
            self._scale_value = scale_value

        def forward(self, x):
            return torch.mul(x, self._scale_value)

    # The per-target commands carry no callable (None): the shared command supplies it.
    commands = [
        PTInsertionCommand(
            PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node.node_name, input_port_id=input_port_id),
            None,
            priority=TransformationPriority.OP_INSERTION_PRIORITY,
        )
        for target_node in target_nodes
    ]

    return PTSharedFnInsertionCommand(commands, SQMultiply(scale_value), scale_node_name)
27 changes: 26 additions & 1 deletion nncf/torch/graph/transformations/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, List

import torch

Expand Down Expand Up @@ -161,6 +161,31 @@ def requires_graph_rebuild(self):
]


class PTSharedFnInsertionCommand(PTTransformationCommand):
    """
    Insertion of a single callable that is shared between several insertion commands.
    """

    def __init__(
        self,
        target_commands: List[PTInsertionCommand],
        fn: Callable,
        op_unique_name: str,
    ):
        """
        :param target_commands: Insertion commands describing where the shared callable is applied.
        :param fn: The callable shared by all `target_commands`.
        :param op_unique_name: Unique name of the shared operation; stored as `op_name`.
        """
        # NOTE(review): the second argument is presumably the target point of the base command;
        # None is passed because the targets live in `target_commands` - confirm against the base class.
        super().__init__(TransformationType.INSERT, None)
        self.target_commands = target_commands
        self.fn = fn
        self.op_name = op_unique_name

    def union(self, other: "PTTransformationCommand") -> "PTTransformationCommand":
        """
        Merging with another command is not supported.

        :raises NotImplementedError: Always.
        """
        # TODO: keep all TransformationCommands atomic, refactor TransformationLayout instead
        raise NotImplementedError()

    def requires_graph_rebuild(self) -> bool:
        """
        Return boolean flag to rebuild graph of model.

        :return: Boolean flag.
        """
        return True


class PTQuantizerInsertionCommand(PTTransformationCommand):
"""
Insertion quantizer operation to the models.
Expand Down
35 changes: 31 additions & 4 deletions nncf/torch/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTModelExtractionWithFusedBiasCommand
from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.commands import PTWeightUpdateCommand
from nncf.torch.graph.transformations.layout import PTTransformationLayout
Expand All @@ -33,6 +34,8 @@
from nncf.torch.nncf_network import ExtraCompressionModuleType
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.nncf_network import PTInsertionPoint
from nncf.torch.quantization.external_quantizer import EXTERNAL_OP_STORAGE_NAME
from nncf.torch.quantization.external_quantizer import ExternalOpCallHook
from nncf.torch.quantization.external_quantizer import ExternalQuantizerCallHook


Expand All @@ -49,6 +52,7 @@ def __init__(self, model: NNCFNetwork):
(PTInsertionCommand, self._apply_insertion_transformations),
(PTQuantizerInsertionCommand, self._apply_quantizer_insertion_transformations),
(PTBiasCorrectionCommand, self._apply_bias_correction_transformations),
(PTSharedFnInsertionCommand, self._apply_shared_nodes_insertion),
(PTWeightUpdateCommand, self._apply_weights_update_transformations),
]

Expand Down Expand Up @@ -105,6 +109,29 @@ def _apply_insertion_transformations(model: NNCFNetwork, transformations: List[P

return model

@staticmethod
def _apply_shared_nodes_insertion(
    model: NNCFNetwork, transformations: List[PTSharedFnInsertionCommand]
) -> NNCFNetwork:
    """
    Apply shared-callable insertion commands to the model.

    Each shared callable is registered once as an EXTERNAL_OP compression module; every
    target command of the shared command then gets an `ExternalOpCallHook` that looks the
    callable up by its id, and the resulting commands are applied as regular insertions.

    :param model: Model to apply transformations to.
    :param transformations: Shared insertion commands. The `fn` attribute of their
        `target_commands` is overwritten in place.
    :return: Model returned by `_apply_insertion_transformations`.
    """
    compression_model_type = ExtraCompressionModuleType.EXTERNAL_OP

    if not model.nncf.is_compression_module_registered(compression_model_type):
        model.nncf.register_compression_module_type(compression_model_type)

    insertion_commands: List[PTInsertionCommand] = []

    # Distinct loop variable names: the original reused `command` for both loops,
    # shadowing the shared command inside its own body.
    for shared_command in transformations:
        target_node_names = ";".join(
            target_command.target_point.target_node_name for target_command in shared_command.target_commands
        )
        op_id = f"{shared_command.op_name}[{target_node_names}]"
        model.nncf.add_compression_module(op_id, shared_command.fn, compression_model_type)

        for target_command in shared_command.target_commands:
            target_command.fn = ExternalOpCallHook(
                EXTERNAL_OP_STORAGE_NAME, model.nncf.get_tracing_context(), op_id
            )
            insertion_commands.append(target_command)

    return PTModelTransformer._apply_insertion_transformations(model, insertion_commands)

@staticmethod
def _apply_quantizer_insertion_transformations(
model: NNCFNetwork, transformations: List[PTQuantizerInsertionCommand]
Expand All @@ -130,7 +157,7 @@ def _apply_quantizer_insertion_transformations(
if target_point.type is not TargetType.OPERATION_WITH_WEIGHTS:
quantizer_id = NonWeightQuantizerId(target_point.target_node_name, target_point.input_port_id)
storage_key = str(quantizer_id)
model.nncf.add_compression_module(storage_key, transformation_command.quantizer, compression_model_type)
model.nncf.add_compression_module(storage_key, fn, compression_model_type)
fn = ExternalQuantizerCallHook(model.nncf.get_tracing_context(), storage_key)

insertion_commands.append(
Expand Down Expand Up @@ -212,9 +239,9 @@ def update_parameter(target_node_name: str, parameter_name: str, new_value: Tens
:param new_value: New parameter value.
:param model: The model.
"""
node = model.nncf.get_containing_module(target_node_name)
assert hasattr(node, parameter_name)
setattr(node, parameter_name, torch.nn.parameter.Parameter(new_value))
module = model.nncf.get_containing_module(target_node_name)
parameter: torch.nn.parameter.Parameter = getattr(module, parameter_name)
parameter.data = new_value


def extraction_potential_fused_modules(node_name: str, model: NNCFNetwork) -> nn.Sequential:
Expand Down
4 changes: 4 additions & 0 deletions nncf/torch/nncf_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from nncf.torch.layer_utils import _NNCFModuleMixin
from nncf.torch.nested_objects_traversal import objwalk
from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules
from nncf.torch.quantization.external_quantizer import EXTERNAL_OP_STORAGE_NAME
from nncf.torch.quantization.external_quantizer import EXTERNAL_QUANTIZERS_STORAGE_NAME
from nncf.torch.utils import compute_FLOPs_hook
from nncf.torch.utils import get_all_modules_by_type
Expand Down Expand Up @@ -117,6 +118,7 @@ def __hash__(self):

class ExtraCompressionModuleType(Enum):
    """Kinds of extra compression module storages."""

    # Each member is mapped to a storage attribute name by
    # `_compression_module_type_to_attr_name`.
    EXTERNAL_QUANTIZER = 0
    EXTERNAL_OP = 1


class NNCFNetworkInterface(torch.nn.Module):
Expand Down Expand Up @@ -554,6 +556,8 @@ def _compression_module_type_to_attr_name(compression_module_type: ExtraCompress
"""
if compression_module_type == ExtraCompressionModuleType.EXTERNAL_QUANTIZER:
return EXTERNAL_QUANTIZERS_STORAGE_NAME
if compression_module_type == ExtraCompressionModuleType.EXTERNAL_OP:
return EXTERNAL_OP_STORAGE_NAME
raise RuntimeError("Unknown extra module type")

def sort_compression_modules(self, compression_module_type: ExtraCompressionModuleType):
Expand Down
24 changes: 18 additions & 6 deletions nncf/torch/quantization/external_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

from nncf.torch.dynamic_graph.context import TracingContext
from nncf.torch.quantization.debug_interface import QuantizationDebugInterface

EXTERNAL_QUANTIZERS_STORAGE_NAME = "external_quantizers"
EXTERNAL_OP_STORAGE_NAME = "external_op"
EXTERNAL_QUANTIZERS_STORAGE_PREFIX = "_nncf." + EXTERNAL_QUANTIZERS_STORAGE_NAME


class ExternalQuantizerCallHook:
class ExternalOpCallHook:
    """
    Callable hook that forwards its call to an operation kept in a named storage.

    The operation is not captured at construction time: on every call it is looked up
    in the storage of the context's `base_module_thread_local_replica`.
    """

    def __init__(self, storage_name: str, context: "TracingContext", storage_key: str):
        """
        :param storage_name: Attribute name of the storage on the replica's `nncf` object.
        :param context: Context providing `base_module_thread_local_replica`.
        :param storage_key: Key of the target operation inside the storage.
        """
        self._storage_name = storage_name
        self._compressed_context = context
        self._storage_key = storage_key

    def __call__(self, *args: Any, **kwargs) -> Any:
        """
        Call the stored operation with the given arguments and return its result.
        """
        # Resolve the storage through the thread-local replica on every call.
        replica = self._compressed_context.base_module_thread_local_replica
        storage = getattr(replica.nncf, self._storage_name)
        return storage[self._storage_key](*args, **kwargs)


class ExternalQuantizerCallHook(ExternalOpCallHook):
"""
Cannot simply register the quantizer module as a callable hook, since we need to call
a thread-local version of the quantizer module during base module execution.
Expand All @@ -28,13 +43,10 @@ def __init__(
quantizer_storage_key: str,
debug_interface: QuantizationDebugInterface = None,
):
self.compressed_context = context
self.quantizer_storage_key = quantizer_storage_key
super().__init__(EXTERNAL_QUANTIZERS_STORAGE_NAME, context, quantizer_storage_key)
self.debug_interface = debug_interface

def __call__(self, *args, **kwargs):
    """
    Report the call to the debug interface (if any) and run the stored quantizer.

    :return: Result of the quantizer call performed by the parent hook.
    """
    if self.debug_interface is not None:
        # The key is stored by ExternalOpCallHook.__init__ as `_storage_key`;
        # `quantizer_storage_key` is no longer assigned on the instance.
        self.debug_interface.register_activation_quantize_call(str(self._storage_key))
    # The result must be returned, otherwise the hook yields None instead of the quantized tensor.
    return super().__call__(*args, **kwargs)

0 comments on commit 84b0a53

Please sign in to comment.