Skip to content

Commit

Permalink
WIP transformations command serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Feb 28, 2024
1 parent 2283d6f commit 922170d
Show file tree
Hide file tree
Showing 21 changed files with 472 additions and 48 deletions.
2 changes: 2 additions & 0 deletions nncf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@
from nncf.parameters import SensitivityMetric as SensitivityMetric
from nncf.parameters import TargetDevice as TargetDevice
from nncf.quantization import QuantizationPreset as QuantizationPreset
from nncf.quantization import apply_transformations as apply_transformations
from nncf.quantization import compress_weights as compress_weights
from nncf.quantization import get_quantization_transformations as get_quantization_transformations
from nncf.quantization import quantize as quantize
from nncf.quantization import quantize_with_accuracy_control as quantize_with_accuracy_control
from nncf.quantization.advanced_parameters import (
Expand Down
2 changes: 2 additions & 0 deletions nncf/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
# limitations under the License.
"""Post-training quantization APIs."""
from nncf.common.quantization.structs import QuantizationPreset as QuantizationPreset
from nncf.quantization.quantize_model import apply_transformations as apply_transformations
from nncf.quantization.quantize_model import compress_weights as compress_weights
from nncf.quantization.quantize_model import get_quantization_transformations as get_quantization_transformations
from nncf.quantization.quantize_model import quantize as quantize
from nncf.quantization.quantize_model import quantize_with_accuracy_control as quantize_with_accuracy_control
25 changes: 24 additions & 1 deletion nncf/quantization/algorithms/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from typing import List, Optional, TypeVar

from nncf import Dataset
from nncf.common.factory import ModelTransformerFactory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType

Expand All @@ -35,7 +37,6 @@ def available_backends(self) -> List[BackendType]:
:return: List of backends supported by the algorithm.
"""

@abstractmethod
def apply(
self,
model: TModel,
Expand All @@ -52,6 +53,28 @@ def apply(
:param dataset: A representative dataset for the calibration process.
:return: A resulting model.
"""
transformation_layout = self.get_transformation_layout(model, graph, statistic_points, dataset)
return self.apply_transformation_layout(transformation_layout)

@abstractmethod
def get_transformation_layout(
self,
model: TModel,
graph: NNCFGraph,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> TransformationLayout:
"""
get_transformation_layout
"""

def apply_transformation_layout(self, model: TModel, transformation_layout: TransformationLayout) -> TModel:
"""
apply_transformation_layout
"""
model_transformer = ModelTransformerFactory.create(model)
transformed_model = model_transformer.transform(transformation_layout)
return transformed_model

@abstractmethod
def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
Expand Down
7 changes: 3 additions & 4 deletions nncf/quantization/algorithms/bias_correction/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,15 @@ def _set_backend_entity(self, model: TModel) -> None:
"Cannot return backend-specific entity because {} is not supported!".format(model_backend.value)
)

def apply(
def get_transformation_layout(
self,
model: TModel,
graph: NNCFGraph,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> TModel:
) -> TransformationLayout:
self._set_backend_entity(model)
main_transformations_layout = TransformationLayout()
main_model_transformer = ModelTransformerFactory.create(model)

model_copy = copy_model(model)
graph_copy = NNCFGraphFactory.create(model_copy)
Expand Down Expand Up @@ -202,7 +201,7 @@ def apply(
# to reduce memory usage during the algorithm's pipeline.
self._remove_unnecessary_stats(position, subgraphs_data)

return main_model_transformer.transform(main_transformations_layout)
return main_transformations_layout

def _is_node_correctable(self, node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
"""
Expand Down
9 changes: 3 additions & 6 deletions nncf/quantization/algorithms/channel_alignment/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from nncf import Dataset
from nncf.common.factory import CommandCreatorFactory
from nncf.common.factory import ModelTransformerFactory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.patterns import GraphPattern
Expand Down Expand Up @@ -90,15 +89,14 @@ def _set_backend_entity(self, model: TModel) -> None:

self._backend_entity = OVChannelAlignmentAlgoBackend()

def apply(
def get_transformation_layout(
self,
model: TModel,
graph: NNCFGraph,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> TModel:
) -> TransformationLayout:
self._set_backend_entity(model)
model_transformer = ModelTransformerFactory.create(model)
transformation_layout = TransformationLayout()

def filter_func(point: StatisticPoint) -> bool:
Expand Down Expand Up @@ -168,8 +166,7 @@ def filter_func(point: StatisticPoint) -> bool:
command = command_creator.create_command_to_insert_bias(container.op, container.bias)
transformation_layout.register(command)

transformed_model = model_transformer.transform(transformation_layout)
return transformed_model
return transformation_layout

@staticmethod
def _align_means(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,13 @@ def _set_backend_entity(self, model: TModel) -> None:
"Cannot return backend-specific entity because {} is not supported!".format(model_backend.value)
)

def apply(
def get_transformation_layout(
self,
model: TModel,
graph: NNCFGraph,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> TModel:
) -> TransformationLayout:
self._set_backend_entity(model)

model_transformer = ModelTransformerFactory.create(model)
Expand Down Expand Up @@ -185,9 +185,7 @@ def apply(
transformation_layout = TransformationLayout()
for node, bias_value in node_and_new_bias_value:
transformation_layout.register(self._backend_entity.create_bias_correction_command(node, bias_value, graph))
transformed_model = model_transformer.transform(transformation_layout)

return transformed_model
return transformation_layout

@staticmethod
def _get_bias_shift_magnitude(current_bias_value: Tensor, updated_bias_value: Tensor) -> float:
Expand Down
11 changes: 3 additions & 8 deletions nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import nncf
from nncf import Dataset
from nncf.common.factory import ModelTransformerFactory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.operator_metatypes import OperatorMetatype
Expand Down Expand Up @@ -803,15 +802,14 @@ def _get_quantization_points_overflow_fix(
output.update(nodes)
return output

def apply(
def get_transformation_layout(
self,
model: TModel,
graph: NNCFGraph,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> TModel:
) -> TransformationLayout:
transformation_layout = TransformationLayout()
model_transformer = ModelTransformerFactory.create(model)
quantization_target_points, unified_scale_groups = self._get_quantization_target_points(model, graph)
quantization_points_overflow_fix = self._get_quantization_points_overflow_fix(
self._overflow_fix, quantization_target_points, graph
Expand Down Expand Up @@ -895,10 +893,7 @@ def filter_func(point: StatisticPoint) -> bool:
graph, quantization_target_point, qconfig, parameters
)
transformation_layout.register(command)
if not transformation_layout.transformations:
nncf_logger.info("The model has no operations to apply quantization.")
quantized_model = model_transformer.transform(transformation_layout)
return quantized_model
return transformation_layout

def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
self._set_backend_entity(model)
Expand Down
10 changes: 5 additions & 5 deletions nncf/quantization/algorithms/min_max/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,22 +283,22 @@ def _create_quantizer(
quantizer = quantizer_cls(quantizer_spec)

# Fill it with minmax
PTMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters)
PTMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters, quantizer_spec.scale_shape)
return quantizer

@staticmethod
def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters) -> None:
def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters, scale_shape) -> None:
if isinstance(quantizer, AsymmetricQuantizer):
quantizer.input_low = torch.nn.Parameter(parameters.input_low.data)
quantizer.input_low = torch.nn.Parameter(parameters.input_low.data.reshape(scale_shape))
input_range = parameters.input_high - parameters.input_low
# Subtract eps from the input_range to make quantizer parameters equal to
# original parameters on the forward call.
quantizer.input_range = torch.nn.Parameter(input_range.data - quantizer.eps)
quantizer.input_range = torch.nn.Parameter((input_range.data - quantizer.eps).reshape(scale_shape))
else:
quantizer.signed = bool(torch.any(parameters.input_low.data < 0))
# Subtract eps from the scale to make quantizer parameters equal to
# original parameters on the forward call.
quantizer.scale = torch.nn.Parameter(parameters.input_high.data - quantizer.eps)
quantizer.scale = torch.nn.Parameter((parameters.input_high.data - quantizer.eps).reshape(scale_shape))

@staticmethod
def _create_quantizer_insertion_command(
Expand Down
19 changes: 16 additions & 3 deletions nncf/quantization/algorithms/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def __init__(self, pipeline_steps: List[PipelineStep]):
:param pipeline_steps: A sequence of pipeline steps to be executed in order.
"""
self._pipeline_steps = pipeline_steps
self._algorithms_applied = False
self._transformation_layout_list = []

@property
def pipeline_steps(self) -> List[PipelineStep]:
Expand Down Expand Up @@ -114,12 +116,22 @@ def run_step(
pipeline_steps = self._remove_unsupported_algorithms(get_backend(model))
pipeline_step = pipeline_steps[step_index]
for algorithm in pipeline_step[:-1]:
current_model = algorithm.apply(current_model, current_graph, step_statistics)
current_model = self._apply_algoritm(algorithm, current_model, step_statistics, current_graph)
current_graph = NNCFGraphFactory.create(current_model)
current_model = pipeline_step[-1].apply(current_model, current_graph, step_statistics)

current_model = self._apply_algoritm(pipeline_step[-1], current_model, step_statistics, current_graph)
return current_model

def _apply_algoritm(
self,
algorithm: Algorithm,
current_model: TModel,
step_statistics: StatisticPointsContainer,
current_graph: NNCFGraph,
) -> TModel:
transformation_layout = algorithm.get_transformation_layout(current_model, current_graph, step_statistics)
self._transformation_layout_list.extend(transformation_layout.transformations)
return algorithm.apply_transformation_layout(current_model, transformation_layout)

def run_from_step(
self,
model: TModel,
Expand Down Expand Up @@ -165,6 +177,7 @@ def run_from_step(

step_graph = None # We should rebuild the graph for the next pipeline step

self._algorithms_applied = True
return step_model

def get_statistic_points_for_step(
Expand Down
15 changes: 15 additions & 0 deletions nncf/quantization/algorithms/post_training/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from nncf import Dataset
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.quantization.structs import QuantizationPreset
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
Expand Down Expand Up @@ -92,6 +93,20 @@ def available_backends(self) -> List[BackendType]:
def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
return self._pipeline.get_statistic_points_for_step(0, model, graph)

def get_transformation_layout(
self,
model: TModel,
graph: NNCFGraph,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> TransformationLayout:
if not self._pipeline._algorithms_applied:
self.apply(model, graph, statistic_points, dataset)
return self.get_transformation_layout(model, graph, statistic_points, dataset)
transformation_layout = TransformationLayout()
transformation_layout.transformations.extend(self._pipeline._transformation_layout_list)
return transformation_layout

def apply(
self,
model: TModel,
Expand Down
9 changes: 2 additions & 7 deletions nncf/quantization/algorithms/smooth_quant/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import nncf
from nncf import Dataset
from nncf.common.factory import ModelTransformerFactory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.operator_metatypes import OperatorMetatype
Expand Down Expand Up @@ -91,18 +90,17 @@ def _set_backend_entity(self, model: TModel) -> None:
"Cannot return backend-specific entity because {} is not supported!".format(model_backend.value)
)

def apply(
def get_transformation_layout(
self,
model: TModel,
graph: NNCFGraph,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> TModel:
) -> TransformationLayout:
self._set_backend_entity(model)
alpha_map = self._get_alpha_map()

nodes_to_smooth_data = self._get_nodes_to_smooth_data(graph, alpha_map.keys())
model_transformer = ModelTransformerFactory.create(model)
transformation_layout = TransformationLayout()

node_groups = self._group_nodes_by_source(nodes_to_smooth_data, graph)
Expand Down Expand Up @@ -169,9 +167,6 @@ def apply(
)
transformation_layout.register(scale_insertion_command)

transformed_model = model_transformer.transform(transformation_layout)
return transformed_model

@staticmethod
def _calculate_scale_and_ratio(
activations: Tensor, weights: Tensor, alpha: float, quantile: Optional[float] = 0.1
Expand Down
12 changes: 11 additions & 1 deletion nncf/quantization/algorithms/smooth_quant/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,30 @@
from nncf.torch.graph.transformations.command_creation import create_command_to_update_weight
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.layer_utils import COMPRESSION_MODULES
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
from nncf.torch.tensor_statistics.collectors import PTAbsMaxReducer
from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor

COMPRESSION_MODULES.register()


class SQMultiply(torch.nn.Module):
def __init__(self, scale_value):
def __init__(self, scale_value=1.0):
super().__init__()
self._scale_value = scale_value

def forward(self, x):
return torch.mul(x, self._scale_value)

def get_state(self):
return {}

@classmethod
def from_state(cls, state):
return cls()


PT_PRE_LAYER_TARGET_TYPE = TargetType.OPERATOR_PRE_HOOK

Expand Down
14 changes: 13 additions & 1 deletion nncf/quantization/algorithms/weight_compression/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging import nncf_logger
from nncf.common.logging.track_progress import track
from nncf.common.scopes import should_consider_scope
Expand Down Expand Up @@ -365,7 +366,18 @@ def do_compression(
return transformed_model

def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
pass
return StatisticPointsContainer()

def get_transformation_layout(
self,
model: TModel,
graph: NNCFGraph,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> TransformationLayout:
raise NotImplementedError(
"get_transformation_layout is not implemented yet for the weights compression algorithm."
)

def _get_activation_node_and_port(self, node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[NNCFNode, int]:
"""
Expand Down
Loading

0 comments on commit 922170d

Please sign in to comment.