Skip to content

Commit

Permalink
[PTQ] Add support of arbitrary batch size for PTQ (#2197)
Browse files Browse the repository at this point in the history
### Changes
Add a new advanced bool option for quantization -
`batchwise_statistics`.
When set to _True_ then statistics collection for supported algorithms
(see below) are calculated with the assumption that the 0-axis of a
tensor is a batch axis.
If the value is False then statistics collection for algorithms is
calculated with an assumption that the tensor has no batch axis.
If set to None statistics collection logic adapts based on the
batch_size of the provided dataset.

These adjustments in statistical computation apply specifically to
MinMax, ChannelAlighnment algorithms.

During the validation of proposed changes on a wide scope of models,
some limitations were observed - if a model contains specific operations
that output in a way that a tensor batch axis starts to contain no batch
meaning anymore, then the statistics after such operations are collected
not precisely.

The handling of such cases is introduced and determined by a warning
message to a user with a recommendation using batch size = 1 for a
specific model or set to False `batchwise_statistics` option.

The torch sample for mobilenet_v2 was updated with `batch_size=128`
value with a new recalculated `subset_size`.
The conformance test was updated with new options `batch_size` and
`dynamic_batch_shape`.
Calibrate.py was updated with a new option `batch_size`.

Algorithm support batch_size > 1:


Algorithm | Do results depend on batch_size? | Comments
-- | -- | --
MinMax | relatively depends | Relatively means that results are
dependant on the correctness of the utilized assumption that batch lays
on the 0-axis. To overcome there is a need to have batch axis
determination algorithm
FastBiascCorrection | Yes | Incorrect statistics calculation with no
regarding batch axis in an aggregator. Need to have batch axis
determination algorithm
BiasCorrection | Yes | Incorrect statistics calculation with no
regarding batch axis in an aggregator. Need to have batch axis
determination algorithm
ChannelAlighnment | No | Checked on models from conformance test:
**mobilenet_v2, mobilenet_v3**
SmoothQuant | No | Checked on models from conformance test: **levit_128,
visformer_small**
PostTrainingQuantization | Yes | Need to have batch axis determination
algorithm




### Reason for changes

Speeding up statistics collection.
SpeedUp on mobilenet_v2 sample (local measurments): 

Backend | bs=1 (sec) | bs=16 (sec) | bs=128 (sec)
-- | -- | -- | --
Torch | 24 | 4 | 4
Torch CUDA | 20 | 1 | 1
OpenVINO | 9 | 4 | 5
ONNX | 17 | 11 | 12

Extend usage scenarios.

### Related tickets

121650

### Tests
Old tests were updated accordingly.
New test added:
test_tensor_collector_batch_size
test_min_max
  • Loading branch information
kshpv authored Mar 22, 2024
1 parent 30b8d9a commit b7ba5ad
Show file tree
Hide file tree
Showing 72 changed files with 1,395 additions and 724 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def get_model_size(ir_path: Path, m_type: str = "Mb", verbose: bool = True) -> f
]
),
)
val_data_loader = torch.utils.data.DataLoader(val_dataset)
batch_size = 128
val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size)

torch_model = models.mobilenet_v2(num_classes=DATASET_CLASSES)
torch_model = load_checkpoint(torch_model)
Expand Down Expand Up @@ -140,8 +141,10 @@ def transform_fn(data_item: Tuple[torch.Tensor, int], device: torch.device) -> t
# item and prepare model input data. The quantize method uses a small subset
# (default: 300 samples) of the calibration dataset.

# Recalculation default subset_size parameter based on batch_size.
subset_size = 300 // batch_size
calibration_dataset = nncf.Dataset(val_data_loader, partial(transform_fn, device=device))
torch_quantized_model = nncf.quantize(torch_model, calibration_dataset)
torch_quantized_model = nncf.quantize(torch_model, calibration_dataset, subset_size=subset_size)

###############################################################################
# Benchmark performance, calculate compression rate and validate accuracy
Expand Down
18 changes: 17 additions & 1 deletion nncf/common/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

from functools import partial
from typing import List, Set
from typing import List, Set, Tuple, Union

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
Expand Down Expand Up @@ -114,3 +114,19 @@ def get_number_of_quantized_ops(
else:
nodes_to_see.extend(graph.get_next_nodes(node))
return len(quantized_ops)


def get_reduction_axes(
channel_axes: Union[List[int], Tuple[int, ...]], shape: Union[List[int], Tuple[int, ...]]
) -> Tuple[int, ...]:
"""
Returns filtered reduction axes without axes that correspond to channels.
:param channel_axes: Channel axes.
:param shape: Shape that need to be filtered.
:return: Reduction axes.
"""
reduction_axes = list(range(len(shape)))
for channel_axis in sorted(channel_axes, reverse=True):
del reduction_axes[channel_axis]
return tuple(reduction_axes)
53 changes: 52 additions & 1 deletion nncf/common/quantization/initialization/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple, Union

from nncf.common.graph.utils import get_reduction_axes
from nncf.common.initialization.dataloader import NNCFDataLoader
from nncf.common.quantization.structs import QuantizationScheme
from nncf.common.quantization.structs import QuantizerGroup
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.config.schemata.defaults import NUM_INIT_SAMPLES
from nncf.experimental.common.tensor_statistics.collectors import AggregationAxes


class RangeInitConfig:
Expand Down Expand Up @@ -204,3 +207,51 @@ def use_means_of_mins(self) -> bool:
@property
def use_means_of_maxs(self) -> bool:
return not self._is_weights and not self._is_per_channel

def _get_reduction_axes(
self,
shape_to_reduce: Union[Tuple[int, ...], List[int]],
quantization_axes: Union[Tuple[int, ...], List[int]],
aggregation_axes: Union[Tuple[int, ...], List[int]],
):
"""
Returns axes for a reducer regarding aggregation axes. As aggregator takes axes counting from stacked tensors,
from these axes only tensor related axes should be used for reducer.
:param shape_to_reduce: Shape of a reduced tensor.
:param quantization_axes: Axes of quantization.
:param aggregation_axes: Axes of aggregator which is applied onto reduced tensor.
:return: Axes for reducer.
"""
axes_to_keep = set(el - 1 for el in aggregation_axes if el != 0)
axes_to_keep.update(quantization_axes)
return get_reduction_axes(axes_to_keep, shape_to_reduce)

def _get_aggregation_axes(self, batchwise_statistics: bool) -> Tuple[int, ...]:
"""
Returns axes for aggregator.
:param batchwise_statistics: Determines whether quantizer statistics should be calculated
for each item of the batch or for the entire batch.
:return Tuple[int]: Aggregation axes.
"""
return (0, 1) if batchwise_statistics else (0,)

def get_reduction_aggregation_axes(
self,
shape_to_reduce: Union[Tuple[int, ...], List[int]],
quantization_axes: Union[Tuple[int, ...], List[int]],
batchwise_statistics: bool,
) -> Tuple[ReductionAxes, AggregationAxes]:
"""
Calculates the reduction axes, aggregation axes for the tensor.
:param shape_to_reduce: Shape of the tensor.
:param quantization_axes: Quantization axes if per-channel quantization.
:param batchwise_statistics: Determines whether quantizer statistics should be calculated
for each item of the batch or for the entire batch.
:return: Reduction axes and aggregation axes.
"""
aggregation_axes = self._get_aggregation_axes(batchwise_statistics)
reduction_axes = self._get_reduction_axes(shape_to_reduce, quantization_axes, aggregation_axes)
return reduction_axes, aggregation_axes
41 changes: 27 additions & 14 deletions nncf/common/tensor_statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
from abc import ABC
from abc import abstractmethod
from itertools import islice
from typing import Any, Dict, TypeVar
from typing import Any, Dict, Optional, TypeVar

import nncf
from nncf.common import factory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging.logger import nncf_logger
from nncf.common.logging.track_progress import track
from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
Expand All @@ -25,6 +26,13 @@
TensorType = TypeVar("TensorType")
TModel = TypeVar("TModel")

EMPTY_DATASET_ERROR = (
"Calibration dataset must not be empty. Please provide calibration dataset with at least one sample."
)
ITERATIONS_NUMBER_WARNING = (
"The number of iterations for statistics collection is bigger than the length of the dataset."
)


class StatisticsAggregator(ABC):
"""
Expand All @@ -36,6 +44,20 @@ def __init__(self, dataset: Dataset):
self.stat_subset_size = None
self.statistic_points = StatisticPointsContainer()

def _get_iterations_number(self) -> Optional[int]:
"""
Returns number of iterations, output number is less than min(self.stat_subset_size, dataset_length).
:return: Number of iterations for statistics collection.
"""
dataset_length = self.dataset.get_length()
if dataset_length and self.stat_subset_size:
if self.stat_subset_size > dataset_length:
nncf_logger.warning(ITERATIONS_NUMBER_WARNING)
return dataset_length
return self.stat_subset_size
return dataset_length or self.stat_subset_size

def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
"""
Collects statistics for registered StatisticPoints.
Expand All @@ -46,34 +68,25 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
"""
if not self.statistic_points:
return

model_transformer = factory.ModelTransformerFactory.create(model)

merged_statistics = self._get_merged_statistic_points(self.statistic_points, model, graph)
transformation_layout = self._get_transformation_layout_extra_outputs(merged_statistics)
model_with_outputs = model_transformer.transform(transformation_layout)
engine = factory.EngineFactory.create(model_with_outputs)

dataset_length = self.dataset.get_length()
total = (
min(dataset_length or self.stat_subset_size, self.stat_subset_size)
if self.stat_subset_size is not None
else None
)
iterations_number = self._get_iterations_number()
empty_statistics = True
for input_data in track(
islice(self.dataset.get_inference_data(), self.stat_subset_size),
total=total,
islice(self.dataset.get_inference_data(), iterations_number),
total=self.stat_subset_size,
description="Statistics collection",
):
outputs = engine.infer(input_data)
processed_outputs = self._process_outputs(outputs)
self._register_statistics(processed_outputs, merged_statistics)
empty_statistics = False
if empty_statistics:
raise nncf.ValidationError(
"Calibration dataset must not be empty. Please provide calibration dataset with at least one sample."
)
raise nncf.ValidationError(EMPTY_DATASET_ERROR)

def register_statistic_points(self, statistic_points: StatisticPointsContainer) -> None:
"""
Expand Down
11 changes: 11 additions & 0 deletions nncf/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ def get_length(self) -> Optional[int]:
return self._data_source.__len__()
return None

def get_batch_size(self) -> Optional[int]:
"""
Tries to fetch batch size of the underlying dataset.
:return: The value of batch_size or _batch_size attributes of the data_source if exist, and None otherwise.
"""
if hasattr(self._data_source, "batch_size"): # Torch dataloader
return self._data_source.batch_size
if hasattr(self._data_source, "_batch_size"): # TF dataloader
return self._data_source._batch_size
return None


class DataProvider(Generic[DataItem, ModelInput]):
def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from nncf.experimental.tensor import Tensor
from nncf.experimental.tensor import TensorDataType
from nncf.experimental.tensor.functions import numeric as fns


def mean_per_channel(x: Tensor, axis: int) -> Tensor:
def mean_per_channel(x: Tensor, axis: int, dtype: Optional[TensorDataType] = None) -> Tensor:
"""
Computes the mean of elements across given channel dimension of Tensor.
:param x: Tensor to reduce.
:param axis: The channel dimensions to reduce.
:param dtype: Type to use in computing the mean.
:return: Reduced Tensor.
"""
if len(x.shape) < 3:
return fns.mean(x, axis=0)
return fns.mean(x, axis=0, dtype=dtype)

pos_axis = axis + x.ndim if axis < 0 else axis
if pos_axis < 0 or pos_axis >= x.ndim:
raise ValueError(f"axis {axis} is out of bounds for array of dimension {x.ndim}")
axis = tuple(i for i in range(x.ndim) if i != pos_axis)
return fns.mean(x, axis=axis)
return fns.mean(x, axis=axis, dtype=dtype)
7 changes: 5 additions & 2 deletions nncf/experimental/tensor/functions/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,16 +355,19 @@ def moveaxis(a: Tensor, source: Union[int, Tuple[int, ...]], destination: Union[

@functools.singledispatch
@tensor_guard
def mean(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Tensor:
def mean(
a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, dtype: TensorDataType = None
) -> Tensor:
"""
Compute the arithmetic mean along the specified axis.
:param a: Array containing numbers whose mean is desired.
:param axis: Axis or axes along which the means are computed.
:param keepdims: Destination positions for each of the original axes. These must also be unique.
:param dtype: Type to use in computing the mean.
:return: Array with moved axes.
"""
return Tensor(mean(a.data, axis, keepdims))
return Tensor(mean(a.data, axis, keepdims, dtype))


@functools.singledispatch
Expand Down
10 changes: 8 additions & 2 deletions nncf/experimental/tensor/functions/numpy_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,14 @@ def _(a: np.ndarray, source: Union[int, Tuple[int, ...]], destination: Union[int


@register_numpy_types(numeric.mean)
def _(a: Union[np.ndarray, np.generic], axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False) -> np.ndarray:
return np.array(np.mean(a, axis=axis, keepdims=keepdims))
def _(
a: Union[np.ndarray, np.generic],
axis: Union[int, Tuple[int, ...]] = None,
keepdims: bool = False,
dtype: Optional[TensorDataType] = None,
) -> np.ndarray:
dtype = DTYPE_MAP[dtype] if dtype else None
return np.array(np.mean(a, axis=axis, keepdims=keepdims, dtype=dtype))


@register_numpy_types(numeric.round)
Expand Down
10 changes: 8 additions & 2 deletions nncf/experimental/tensor/functions/torch_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,14 @@ def _(a: torch.Tensor, source: Union[int, Tuple[int, ...]], destination: Union[i


@numeric.mean.register(torch.Tensor)
def _(a: torch.Tensor, axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False) -> torch.Tensor:
return torch.mean(a, dim=axis, keepdim=keepdims)
def _(
a: torch.Tensor,
axis: Union[int, Tuple[int, ...]] = None,
keepdims: bool = False,
dtype: Optional[TensorDataType] = None,
) -> torch.Tensor:
dtype = DTYPE_MAP[dtype] if dtype else None
return torch.mean(a, dim=axis, keepdim=keepdims, dtype=dtype)


@numeric.round.register(torch.Tensor)
Expand Down
9 changes: 9 additions & 0 deletions nncf/onnx/graph/metatypes/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,12 @@
onnx_metatypes.ONNXQuantizeLinearMetatype,
onnx_metatypes.ONNXDequantizeLinearMetatype,
]

# These metatypes mix outputs for different samples into one axis.
# If reducers and aggregators collect statistics at the output of the following operations,
# assuming that 0-axis is batch axis, they get only 1 value instead of batch_size values.
# It could lead to inaccurate/incorrect statistics result.
OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS = [
onnx_metatypes.ONNXROIAlignMetatype,
onnx_metatypes.ONNXEmbeddingMetatype,
]
8 changes: 4 additions & 4 deletions nncf/onnx/graph/metatypes/onnx_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class ONNXGemmMetatype(ONNXOpWithWeightsMetatype):
name = "GemmOp"
op_names = ["Gemm"]
hw_config_names = [HWConfigOpName.MATMUL]
weight_channel_axis = -1
weight_channel_axis = -1 # For port_id=1
weight_port_ids = None
bias_port_id = 2
possible_weight_ports = [0, 1]
Expand All @@ -142,7 +142,7 @@ class ONNXMatMulMetatype(ONNXOpMetatype):
name = "MatMulOp"
op_names = ["MatMul"]
hw_config_names = [HWConfigOpName.MATMUL]
weight_channel_axis = -1
weight_channel_axis = -1 # For port_id=1
weight_port_ids = None
bias_port_id = 2
possible_weight_ports = [0, 1]
Expand Down Expand Up @@ -463,8 +463,8 @@ class ONNXScatterNDMetatype(ONNXOpMetatype):


@ONNX_OPERATION_METATYPES.register()
class ONNXRoiAlignMetatype(ONNXOpMetatype):
name = "RoiAlignOp"
class ONNXROIAlignMetatype(ONNXOpMetatype):
name = "ROIAlignOp"
op_names = ["RoiAlign"]


Expand Down
Loading

0 comments on commit b7ba5ad

Please sign in to comment.