Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop' into dl/narrow_range_to…
Browse files Browse the repository at this point in the history
…_qconfig
  • Loading branch information
daniil-lyakhov committed Jan 31, 2025
2 parents f007ffa + f0cb70c commit ea5d0fd
Show file tree
Hide file tree
Showing 26 changed files with 749 additions and 210 deletions.
21 changes: 15 additions & 6 deletions nncf/experimental/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,18 +464,27 @@ def _reduce_out_of_place(self, x: List[Tensor]) -> List[Tensor]:


class MeanVarianceReducer(TensorReducerBase):
    """
    Reducer that turns the first input tensor into the mean of its variances.
    """

    def _reduce_out_of_place(self, x: List[Tensor]) -> List[Tensor]:
        """
        Computes the variance of the first input tensor over the reduction axes
        and averages the obtained variances.

        :param x: Input tensors; only the first element is used.
        :return: Single-element list with the mean of the computed variances.
        """
        x = x[0]
        reduction_axes = self._get_reduction_axes(x)
        variance = fns.var(x, reduction_axes)
        return [fns.mean(variance)]


class MaxVarianceReducer(TensorReducerBase):
    """
    Reducer that turns the first input tensor into the maximum of its variances.
    """

    def _reduce_out_of_place(self, x: List[Tensor]) -> List[Tensor]:
        """
        Computes the variance of the first input tensor over the reduction axes
        and takes the maximum of the obtained variances.

        :param x: Input tensors; only the first element is used.
        :return: Single-element list with the maximum of the computed variances.
        """
        x = x[0]
        reduction_axes = self._get_reduction_axes(x)
        variance = fns.var(x, reduction_axes)
        return [fns.max(variance)]


class MeanAbsMaxReducer(TensorReducerBase):
    """
    Reducer that turns the first input tensor into the mean of its absolute maxima.
    """

    def _reduce_out_of_place(self, x: List[Tensor]) -> List[Tensor]:
        """
        Takes the absolute value of the first input tensor, finds its maximum over
        the reduction axes and averages the obtained maxima.

        :param x: Input tensors; only the first element is used.
        :return: Single-element list with the mean of the absolute maxima.
        """
        x = fns.abs(x[0])
        reduction_axes = self._get_reduction_axes(x)
        # The reducer-level keepdims setting is forwarded to the max reduction only;
        # the final mean is taken over the whole result.
        abs_max = fns.max(x, reduction_axes, keepdims=self._keepdims)
        return [fns.mean(abs_max)]


class QuantileReducerBase(TensorReducerBase):
Expand Down
1 change: 0 additions & 1 deletion nncf/experimental/torch/fx/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ def compress_weights_impl(
"""
Implementation of the `compress_weights()` method for the Torch Fx backend.
"""

compression_algorithm = WeightCompression(
mode,
ratio,
Expand Down
44 changes: 18 additions & 26 deletions nncf/quantization/algorithms/weight_compression/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import operator
from collections import OrderedDict
from collections import defaultdict
from functools import partial
from functools import reduce
from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar

Expand Down Expand Up @@ -266,6 +265,14 @@ def __init__(
subset_size=gptq_params.subset_size,
scale_estimation=self._scale_estimation,
)
if self._scale_estimation:
scale_estimation_params = self._advanced_parameters.scale_estimation_params
self._scale_estimation_algo = ScaleEstimation(
scale_estimation_params.subset_size,
scale_estimation_params.initial_steps,
scale_estimation_params.scale_steps,
scale_estimation_params.weight_penalty,
)

self._data_aware_mixed_precision = (
self._sensitivity_metric != SensitivityMetric.WEIGHT_QUANTIZATION_ERROR and self._ratio != 1.0
Expand Down Expand Up @@ -616,18 +623,13 @@ def apply(
)
else:
if self._scale_estimation:
scale_estimation_params = self._advanced_parameters.scale_estimation_params
scales, zero_points = ScaleEstimation(
model,
self._backend_entity.name_to_node_mapping,
all_weight_params,
nodes_to_compress,
statistics,
scale_estimation_params.subset_size,
scale_estimation_params.initial_steps,
scale_estimation_params.scale_steps,
scale_estimation_params.weight_penalty,
).apply(model, graph)
scales, zero_points = self._scale_estimation_algo.apply(
model=model,
graph=graph,
all_weight_params=all_weight_params,
statistics=statistics,
backend_entity=self._backend_entity,
)

if self._lora_correction:
lora_correction_params = self._advanced_parameters.lora_correction_params
Expand Down Expand Up @@ -702,8 +704,6 @@ def get_matmul_input_to_output_nodes_map(
"""
matmul_input_to_output_nodes_map = defaultdict(list)
for node in matmul_nodes:
if node.layer_attributes.input_attributes["transpose"]: # It works only for OV
raise nncf.UnsupportedModelError("Transposed input is not supported")
act_node, output_port_id = self._get_activation_node_and_port(node, graph)
matmul_input_to_output_nodes_map[(act_node, output_port_id)].append(node)
return matmul_input_to_output_nodes_map
Expand Down Expand Up @@ -811,16 +811,6 @@ def _get_statistics_for_weights_compression(
:return: Collected statistics.
"""

def input_filter_func(point, port_id):
# For the floating-point statistics collected in POST_LAYER style,
# we also need to determine the output port id.
# For the cases when the layer has more than one (0) output port.
return (
self._algorithm_key in point.algorithm_to_tensor_collectors
and point.target_point.type == TargetType.POST_LAYER_OPERATION
and point.target_point.port_id == port_id
)

# For each node we store statistics in a WCTensorStatistics data-class. It contains the following fields:
# mean_values=[mean_value_1, ..., mean_value_n]
# shapes=[shape_1, ..., shape_n]
Expand All @@ -830,7 +820,9 @@ def input_filter_func(point, port_id):
for (act_node, output_port_id), matmul_nodes in matmul_input_to_output_nodes_map.items():
tensor_collectors = list(
statistic_points.get_algo_statistics_for_node(
act_node.node_name, partial(input_filter_func, port_id=output_port_id), self._algorithm_key
act_node.node_name,
self._backend_entity.get_filter_fn_for_statistics(output_port_id, self._algorithm_key),
self._algorithm_key,
)
)
# Statistics could be empty in case when the statistics is registered for another algorithm,
Expand Down
14 changes: 13 additions & 1 deletion nncf/quantization/algorithms/weight_compression/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@

from abc import ABC
from abc import abstractmethod
from typing import Dict, Iterable, List, Optional, Tuple, TypeVar
from typing import Callable, Dict, Iterable, List, Optional, Tuple, TypeVar

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.transformations.commands import TargetPoint
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.experimental.common.tensor_statistics.collectors import HAWQAggregator
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
Expand Down Expand Up @@ -234,6 +235,17 @@ def dump_parameters(
:param path: Optional list of the paths.
"""

@staticmethod
@abstractmethod
def get_filter_fn_for_statistics(
    activation_port_id: int,
    algorithm_key: str,
) -> Callable[[StatisticPoint], bool]:
    """
    Builds a backend-specific predicate used to select statistic points.

    The returned callable receives a statistic point and tells whether the
    statistics stored in it should be taken for the given activation port
    and algorithm.

    :param activation_port_id: Activation port id for the statistic collection target node.
    :param algorithm_key: Current algorithm key.
    :return: Backend-specific callable that filters statistic containers by their statistic point.
    """


class AWQAlgoBackend(WeightCompressionAlgoBackend):
@staticmethod
Expand Down
23 changes: 9 additions & 14 deletions nncf/quantization/algorithms/weight_compression/mixed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,18 @@ class DataBasedCriterion(DataFreeCriterion, ABC):

@property
def available_backends(self) -> List[BackendType]:
    """
    :return: Backends for which a backend entity can be created by this criterion.
    """
    # Torch is listed together with OpenVINO because `_set_backend_entity`
    # has a dedicated branch for each of them.
    return [BackendType.OPENVINO, BackendType.TORCH]

def _set_backend_entity(self, model: TModel) -> None:
model_backend = get_backend(model)
if model_backend == BackendType.OPENVINO:
from nncf.quantization.algorithms.weight_compression.openvino_backend import OVMixedPrecisionAlgoBackend

self._backend_entity = OVMixedPrecisionAlgoBackend(model)
elif model_backend == BackendType.TORCH:
from nncf.quantization.algorithms.weight_compression.torch_backend import PTMixedPrecisionAlgoBackend

self._backend_entity = PTMixedPrecisionAlgoBackend()
else:
raise nncf.UnsupportedBackendError(
"Cannot return backend-specific entity because {} is not supported!".format(model_backend.value)
Expand Down Expand Up @@ -303,21 +307,12 @@ def _get_activation_node_and_port(self, node: NNCFNode, nncf_graph: NNCFGraph) -
def _get_statistics_for_node(
self, statistic_points: StatisticPointsContainer, node: NNCFNode, nncf_graph: NNCFGraph, stat_key: str
) -> List[Tensor]:
act_node, output_port_id = self._get_activation_node_and_port(node, nncf_graph)

def input_filter_func(point):
# For the floating-point statistics collected in POST_LAYER style,
# we also need to determine the output port id.
# For the cases when the layer has more than one (0) output port.
return (
self._algorithm_key in point.algorithm_to_tensor_collectors
and point.target_point.type == TargetType.POST_LAYER_OPERATION
and point.target_point.port_id == output_port_id
)

act_node, act_port_id = self._get_activation_node_and_port(node, nncf_graph)
stats = []
for tensor_collector in statistic_points.get_algo_statistics_for_node(
act_node.node_name, input_filter_func, self._algorithm_key
act_node.node_name,
self._backend_entity.get_filter_fn_for_statistics(act_port_id, self._algorithm_key),
self._algorithm_key,
):
statistics = tensor_collector.get_statistics()
for data in statistics.get_data().values():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Iterable, List, Optional, Tuple
from typing import Callable, Dict, Iterable, List, Optional, Tuple

import openvino as ov
from openvino.runtime import opset13 as opset
Expand All @@ -19,6 +19,7 @@
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.utils import get_reduction_axes
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.common.utils.caching import disable_results_caching
from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator
from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
Expand Down Expand Up @@ -109,6 +110,8 @@ def mean_statistic_collector(

@staticmethod
def get_activation_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
if node.layer_attributes.input_attributes["transpose"]:
raise nncf.UnsupportedModelError("Transposed input is not supported")
constant_ports = node.layer_attributes.get_const_port_ids()
activation_ports = [
e.input_port_id for e in nncf_graph.get_input_edges(node) if e.input_port_id not in constant_ports
Expand Down Expand Up @@ -348,6 +351,17 @@ def dump_parameters(
) -> None:
dump_parameters(model, parameters, algo_name, path)

@staticmethod
def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) -> Callable[[StatisticPoint], bool]:
    """
    Builds a predicate that selects the statistic points of the given algorithm
    collected after a layer on the given output port.

    :param activation_port_id: Port id the statistic point target must have.
    :param algorithm_key: Key that must be present among the point's tensor collectors.
    :return: Callable returning True only for the matching statistic points.
    """

    def filter_func(point: StatisticPoint) -> bool:
        # Points registered for other algorithms are rejected first.
        if algorithm_key not in point.algorithm_to_tensor_collectors:
            return False
        target = point.target_point
        # Only post-layer points are accepted, and the port id must match as well.
        return target.type == TargetType.POST_LAYER_OPERATION and target.port_id == activation_port_id

    return filter_func


class OVAWQAlgoAlgoBackend(AWQAlgoBackend, OVWeightCompressionAlgoBackend):
@staticmethod
Expand Down
56 changes: 23 additions & 33 deletions nncf/quantization/algorithms/weight_compression/scale_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,17 @@
# limitations under the License.

from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, TypeVar
from typing import Dict, List, Optional, Tuple, TypeVar

import nncf
from nncf import Dataset
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.logging.track_progress import track
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic
from nncf.parameters import CompressWeightsMode
from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats
from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_normalized_weight_and_fp4_scale
Expand All @@ -45,70 +43,57 @@ class ScaleEstimation:

def __init__(
    self,
    subset_size: int = 32,
    initial_steps: int = 5,
    scale_steps: int = 10,
    weight_penalty: float = -1.0,
):
    """
    Stores the hyperparameters of the algorithm. The model, weight parameters,
    statistics and the backend entity are not kept here.

    :param subset_size: The number of samples for scale estimation.
    :param initial_steps: The number of the steps for absmax scale rectification.
    :param scale_steps: The number of the steps for grid search scale rectification
        from 1.0 to 1.0 - 0.05 * scale_step.
    :param weight_penalty: coefficient for penalty between fp and compressed weights. If -1 then doesn't apply.
    """
    super().__init__()
    self._subset_size = subset_size
    self._initial_steps = initial_steps
    self._scale_steps = scale_steps
    self._weight_penalty = weight_penalty

@property
def available_backends(self) -> List[BackendType]:
    """
    :return: Backends for which a backend entity can be created by this algorithm.
    """
    # Torch is listed together with OpenVINO because `_set_backend_entity`
    # has a dedicated branch for each of them.
    return [BackendType.OPENVINO, BackendType.TORCH]

def _set_backend_entity(self, model: TModel) -> None:
    """
    Creates a helper class with a backend-specific logic of the algorithm.

    :param model: Backend-specific input model.
    :raises nncf.UnsupportedBackendError: If the backend of the model is neither OpenVINO nor Torch.
    """
    model_backend = get_backend(model)
    if model_backend == BackendType.OPENVINO:
        # Imported inside the branch - presumably so that the OpenVINO backend
        # module is loaded only when an OpenVINO model is processed.
        from nncf.quantization.algorithms.weight_compression.openvino_backend import OVWeightCompressionAlgoBackend

        self._backend_entity = OVWeightCompressionAlgoBackend(model)
    elif model_backend == BackendType.TORCH:
        from nncf.quantization.algorithms.weight_compression.torch_backend import PTWeightCompressionAlgoBackend

        self._backend_entity = PTWeightCompressionAlgoBackend()
    else:
        raise nncf.UnsupportedBackendError(
            "Cannot return backend-specific Scale Estimation entity because {} is not supported!".format(
                model_backend.value
            )
        )

def apply(
self,
model: TModel,
graph: NNCFGraph,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
all_weight_params: List[WeightCompressionParameters],
statistics: Dict[str, WCTensorStatistic],
backend_entity: Optional[WeightCompressionAlgoBackend] = None,
) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
"""
Estimates better scale for the int4 nodes in the model.
Expand All @@ -119,23 +104,28 @@ def apply(
:param model: Model for applying algorithm.
:param graph: Model graph.
:param all_weight_params: List of all weight parameters.
:param statistics: Input activation statistics for each node.
:param statistic_points: Statistic points with collected statistics values.
:param dataset: A representative dataset for the calibration process.
:param backend_entity: Weight compression algorithm backend.
:return: Two dictionaries for estimated scales and zero points for each weight name.
"""

self._backend_entity = backend_entity
if self._backend_entity is None:
self._set_backend_entity(model)
scales, zero_points = dict(), dict()

for wp in track(self._all_weight_params, description="Applying Scale Estimation"):
for wp in track(all_weight_params, description="Applying Scale Estimation"):
weight_name = wp.weight_name
node_name = wp.node_with_weight.node_name
config = wp.compression_config

if config.num_bits != 4 or node_name not in self._statistics:
if config.num_bits != 4 or node_name not in statistics:
scales[weight_name] = None
continue

stats = self._statistics[node_name]
stats = statistics[node_name]

weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph)
if len(weight_data) != 1: # not supported by the algorithm
Expand Down
Loading

0 comments on commit ea5d0fd

Please sign in to comment.