
Commit

Merge branch 'develop' into ad/pt_graph_const_minmax
AlexanderDokuchaev committed Apr 16, 2024
2 parents 95dc198 + 47342f2 commit a7878a5
Showing 30 changed files with 420 additions and 159 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/nightly.yml
@@ -10,4 +10,4 @@ jobs:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- uses: AlexanderDokuchaev/md-dead-link-check@v0.6
- uses: AlexanderDokuchaev/md-dead-link-check@v0.8
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit-linters.yml
@@ -23,4 +23,4 @@ jobs:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- uses: AlexanderDokuchaev/md-dead-link-check@v0.6
- uses: AlexanderDokuchaev/md-dead-link-check@v0.8
6 changes: 5 additions & 1 deletion README.md
@@ -374,6 +374,10 @@ pip install nncf[torch]

Other viable options besides `[torch]` are `[tf]`, `[onnx]` and `[openvino]`.

> [!WARNING]
> Installing the package with an extras specifier such as `pip install nncf[torch]` is deprecated and will be removed in a future release.
> Instead, install the backend dependency separately (e.g., `pip install torch`) or declare it explicitly in your requirements file.

NNCF is also available via [conda](https://anaconda.org/conda-forge/nncf):

```bash
conda install -c conda-forge nncf
```

@@ -383,7 +387,7 @@ conda install -c conda-forge nncf
### System requirements

- Ubuntu\* 18.04 or later (64-bit)
- Python\* 3.7 or later
- Python\* 3.8 or later
- Supported frameworks:
- PyTorch\* >=2.1, <2.3
- TensorFlow\* >=2.8.4, <=2.12.1
24 changes: 1 addition & 23 deletions docs/Installation.md
@@ -12,14 +12,6 @@ NNCF can be installed as a regular PyPI package via pip:
```bash
pip install nncf
```

If you want to install both NNCF and the supported PyTorch version in one line, you can do this by simply running:

```bash
pip install nncf[torch]
```

Other viable options besides `[torch]` are `[tf]`, `[onnx]` and `[openvino]`.

## As a package built from a checked-out repository

Install the package and its dependencies by running the following command in the repository root directory:
@@ -28,20 +20,6 @@ Install the package and its dependencies by running the following command in the
```bash
pip install .
```

Use the same `pip install` syntax as above to install NNCF along with the backend package version in one go:

```bash
pip install .[<BACKEND>]
```

List of supported backends: `torch`, `tf`, `onnx` and `openvino`.

For development purposes install extra packages by

```bash
pip install .[dev,tests]
```

_NB_: For launching example scripts in this repository, we recommend setting the `PYTHONPATH` variable to the root of the checked-out repository once the installation is completed.

NNCF is also available via [conda](https://anaconda.org/conda-forge/nncf):
@@ -65,7 +43,7 @@ as well as the supported versions of Python:

| NNCF | OpenVINO | PyTorch | ONNX | TensorFlow | Python |
|-----------|------------|----------|----------|------------|--------|
| `develop` | `2024.4.0` | `2.2.1` | `1.13.1` | `2.12.0` | `3.8` |
| `develop` | `2024.4.0` | `2.2.1` | `1.16.0` | `2.12.0` | `3.8` |
| `2.9.0` | `2024.4.0` | `2.1.2` | `1.13.1` | `2.12.0` | `3.8` |
| `2.8.1` | `2023.3.0` | `2.1.2` | `1.13.1` | `2.12.0` | `3.8` |
| `2.8.0` | `2023.3.0` | `2.1.2` | `1.13.1` | `2.12.0` | `3.8` |
@@ -195,20 +195,17 @@ def validation_ac(
preset=nncf.QuantizationPreset.MIXED,
ignored_scope=nncf.IgnoredScope(
types=["Mul", "Sub", "Sigmoid"], # ignore operations
names=[
"/model.22/dfl/conv/Conv", # in the post-processing subgraph
"/model.22/Add",
"/model.22/Add_1",
"/model.22/Add_2",
"/model.22/Add_3",
"/model.22/Add_4",
"/model.22/Add_5",
"/model.22/Add_6",
"/model.22/Add_7",
"/model.22/Add_8",
"/model.22/Add_9",
"/model.22/Add_10",
"/model.22/Add_11",
subgraphs=[
nncf.Subgraph(
inputs=[
"/model.22/Concat_3",
"/model.22/Concat_6",
"/model.22/Concat_24",
"/model.22/Concat_5",
"/model.22/Concat_4",
],
outputs=["/model.22/Concat_29"],
)
],
),
)
20 changes: 6 additions & 14 deletions examples/post_training_quantization/openvino/yolov8/main.py
@@ -122,20 +122,12 @@ def transform_fn(data_item: Dict):
quantization_dataset,
preset=nncf.QuantizationPreset.MIXED,
ignored_scope=nncf.IgnoredScope(
types=["Multiply", "Subtract", "Sigmoid"], # ignore operations
names=[
"/model.22/dfl/conv/Conv", # in the post-processing subgraph
"/model.22/Add",
"/model.22/Add_1",
"/model.22/Add_2",
"/model.22/Add_3",
"/model.22/Add_4",
"/model.22/Add_5",
"/model.22/Add_6",
"/model.22/Add_7",
"/model.22/Add_8",
"/model.22/Add_9",
"/model.22/Add_10",
types=["Multiply", "Subtract", "Sigmoid"],
subgraphs=[
nncf.Subgraph(
inputs=["/model.22/Concat", "/model.22/Concat_1", "/model.22/Concat_2"],
outputs=["output0/sink_port_0"],
)
],
),
)
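The YOLOv8 example diffs above and below replace long per-node `names` lists with a single `nncf.Subgraph` entry. A minimal sketch of the resulting call, assuming an OpenVINO model and a calibration dataset prepared as in the example (the model path, data loader, and transform function below are placeholders):

```python
import openvino as ov

import nncf

model = ov.Core().read_model("yolov8n.xml")  # placeholder path
calibration_dataset = nncf.Dataset(data_loader, transform_fn)  # both assumed to exist, as in main.py

quantized_model = nncf.quantize(
    model,
    calibration_dataset,
    preset=nncf.QuantizationPreset.MIXED,
    ignored_scope=nncf.IgnoredScope(
        types=["Multiply", "Subtract", "Sigmoid"],  # skip these op types
        subgraphs=[
            # Keep the whole post-processing subgraph in floating point instead of
            # enumerating every node in it by name.
            nncf.Subgraph(
                inputs=["/model.22/Concat", "/model.22/Concat_1", "/model.22/Concat_2"],
                outputs=["output0/sink_port_0"],
            )
        ],
    ),
)
```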
@@ -186,20 +186,17 @@ def validation_ac(
preset=nncf.QuantizationPreset.MIXED,
ignored_scope=nncf.IgnoredScope(
types=["Multiply", "Subtract", "Sigmoid"], # ignore operations
names=[
"/model.22/dfl/conv/Conv", # in the post-processing subgraph
"/model.22/Add",
"/model.22/Add_1",
"/model.22/Add_2",
"/model.22/Add_3",
"/model.22/Add_4",
"/model.22/Add_5",
"/model.22/Add_6",
"/model.22/Add_7",
"/model.22/Add_8",
"/model.22/Add_9",
"/model.22/Add_10",
"/model.22/Add_11",
subgraphs=[
nncf.Subgraph(
inputs=[
"/model.22/Concat_3",
"/model.22/Concat_6",
"/model.22/Concat_24",
"/model.22/Concat_5",
"/model.22/Concat_4",
],
outputs=["output0"],
)
],
),
)
2 changes: 2 additions & 0 deletions nncf/__init__.py
@@ -45,7 +45,9 @@
from nncf.quantization.advanced_parameters import (
AdvancedAccuracyRestorerParameters as AdvancedAccuracyRestorerParameters,
)
from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters as AdvancedBiasCorrectionParameters
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters as AdvancedQuantizationParameters
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters as AdvancedSmoothQuantParameters
from nncf.quantization.advanced_parameters import OverflowFix as OverflowFix
from nncf.scopes import IgnoredScope as IgnoredScope
from nncf.scopes import Subgraph as Subgraph
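The two new re-exports make the advanced parameter classes reachable from the top-level package. A small sanity-check sketch of what this enables:

```python
import nncf

# After this change the advanced parameter dataclasses resolve at the package top level,
# so callers no longer need to import from nncf.quantization.advanced_parameters directly.
assert hasattr(nncf, "AdvancedBiasCorrectionParameters")
assert hasattr(nncf, "AdvancedSmoothQuantParameters")
```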
2 changes: 1 addition & 1 deletion nncf/quantization/advanced_parameters.py
@@ -59,7 +59,7 @@ class OverflowFix(StrEnum):


@api()
class FP8Type(Enum):
class FP8Type(StrEnum):
"""
Defines FP8 special types (https://arxiv.org/pdf/2209.05433.pdf).
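Changing the base class from `Enum` to `StrEnum` mainly affects how the members behave in string contexts. A minimal illustration that does not assume the exact member names or values:

```python
from nncf.quantization.advanced_parameters import FP8Type

# With a str-based enum every member is also a plain string, so it compares equal to its
# value and serializes cleanly into JSON/YAML configuration files.
for member in FP8Type:
    assert isinstance(member, str)
    assert member == member.value
```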
38 changes: 21 additions & 17 deletions nncf/quantization/algorithms/min_max/algorithm.py
@@ -863,25 +863,29 @@ def filter_func(point: StatisticPoint) -> bool:
group_statistics.append(statistics)

unified_values = self._backend_entity.unify_statistics(group_statistics)
for quantization_target_point in unified_scale_group:
qconfig = quantization_target_points[quantization_target_point]
q_group = QuantizerGroup.ACTIVATIONS
narrow_range = get_quantizer_narrow_range(qconfig, q_group)
if self._mode is not None:
destination_type = self._quantization_params[q_group].destination_type
parameters = calculate_convert_parameters(
unified_values, is_per_channel=qconfig.per_channel, destination_type=destination_type
)
command = self._backend_entity.create_convert_insertion_command(
quantization_target_point, parameters
)
else:
parameters = calculate_quantizer_parameters(unified_values, qconfig, q_group, narrow_range)
command = self._backend_entity.create_quantizer_insertion_command(
graph, quantization_target_point, qconfig, parameters
qconfigs = [quantization_target_points[qtp] for qtp in unified_scale_group]
if any(qconfigs[0] != qconfig for qconfig in qconfigs[1:]):
raise nncf.InternalError(f"QConfigs for unified scale group {unified_scale_group} are not equal")
qconfig = qconfigs[0]
q_group = QuantizerGroup.ACTIVATIONS
narrow_range = get_quantizer_narrow_range(qconfig, q_group)
if self._mode is not None:
destination_type = self._quantization_params[q_group].destination_type
parameters = calculate_convert_parameters(
unified_values, is_per_channel=qconfig.per_channel, destination_type=destination_type
)
for quantization_target_point in unified_scale_group:
transformation_layout.register(
self._backend_entity.create_convert_insertion_command(quantization_target_point, parameters)
)
continue
parameters = calculate_quantizer_parameters(unified_values, qconfig, q_group, narrow_range)
commands = self._backend_entity.create_unified_scales_quantizers_insertion_commands(
graph, unified_scale_group, qconfig, parameters
)
for command in commands:
transformation_layout.register(command)
unified_ops_list.add(quantization_target_point)
unified_ops_list.update(unified_scale_group)

for quantization_target_point, qconfig in quantization_target_points.items():
if quantization_target_point in unified_ops_list:
21 changes: 20 additions & 1 deletion nncf/quantization/algorithms/min_max/backend.py
@@ -141,12 +141,31 @@ def create_quantizer_insertion_command(
Returns backend-specific quantizer insertion command.
:param nncf_graph: NNCFGraph to get input/output shapes for the target point.
:param target_point: Target location for the correction.
:param target_point: Target location for the quantizer insertion.
:param quantizer_config: QuantizerConfig instance for the current layer.
:param parameters: FakeQuantizeParameters to calculate activation quantization parameters.
:return: Backend-specific TransformationCommand for the quantizer insertion operation.
"""

@staticmethod
@abstractmethod
def create_unified_scales_quantizers_insertion_commands(
nncf_graph: NNCFGraph,
target_points: List[TargetPoint],
quantizer_config: QuantizerConfig,
parameters: FakeQuantizeParameters,
) -> List[TransformationCommand]:
"""
Returns backend-specific unified scales quantizers insertion commands.
:param nncf_graph: NNCFGraph to get input/output shapes for the target point.
:param target_points: List of target locations for the quantizers insertion.
:param quantizer_config: QuantizerConfig instance for the current layer.
:param parameters: FakeQuantizeParameters to calculate activation quantization parameters.
:return: List of backend-specific TransformationCommands
for the quantizers with unified scales insertion operations.
"""

@staticmethod
@abstractmethod
def create_convert_insertion_command(
16 changes: 15 additions & 1 deletion nncf/quantization/algorithms/min_max/onnx_backend.py
@@ -118,7 +118,7 @@ def create_quantizer_insertion_command(
target_point: ONNXTargetPoint,
quantizer_config: QuantizerConfig,
parameters: FakeQuantizeParameters,
):
) -> ONNXQuantizerInsertionCommand:
tensor_type = np.int8 if np.any(parameters.input_low.data < 0) else np.uint8
is_weight = target_point.is_weight_target_point()
if is_weight:
@@ -131,6 +131,20 @@ def create_quantizer_insertion_command(
onnx_parameters = convert_fq_params_to_onnx_params(parameters, quantizer_config.num_bits, tensor_type, axis)
return ONNXQuantizerInsertionCommand(target_point, nncf_input_node_next_nodes, onnx_parameters)

@staticmethod
def create_unified_scales_quantizers_insertion_commands(
nncf_graph: NNCFGraph,
target_points: List[ONNXTargetPoint],
quantizer_config: QuantizerConfig,
parameters: FakeQuantizeParameters,
) -> List[ONNXQuantizerInsertionCommand]:
return [
ONNXMinMaxAlgoBackend.create_quantizer_insertion_command(
nncf_graph, target_point, quantizer_config, parameters
)
for target_point in target_points
]

@staticmethod
def create_convert_insertion_command(
target_point: ONNXTargetPoint,
9 changes: 9 additions & 0 deletions nncf/quantization/algorithms/min_max/openvino_backend.py
@@ -120,6 +120,15 @@ def create_quantizer_insertion_command(
) -> OVQuantizerInsertionCommand:
return OVQuantizerInsertionCommand(target_point, parameters)

@staticmethod
def create_unified_scales_quantizers_insertion_commands(
nncf_graph: NNCFGraph,
target_points: List[OVTargetPoint],
quantizer_config: QuantizerConfig,
parameters: FakeQuantizeParameters,
) -> List[OVQuantizerInsertionCommand]:
return [OVQuantizerInsertionCommand(target_point, parameters) for target_point in target_points]

@staticmethod
def create_convert_insertion_command(
target_point: OVTargetPoint,
17 changes: 17 additions & 0 deletions nncf/quantization/algorithms/min_max/torch_backend.py
@@ -37,6 +37,7 @@
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.graph import PTTargetPoint
from nncf.torch.graph.transformations.command_creation import create_quantizer_insertion_command
from nncf.torch.graph.transformations.command_creation import create_shared_quantizer_insertion_command
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.hardware.config import PTHWConfig
@@ -296,6 +297,22 @@ def create_quantizer_insertion_command(
)
return create_quantizer_insertion_command(target_point, quantizer)

@staticmethod
def create_unified_scales_quantizers_insertion_commands(
nncf_graph: NNCFGraph,
target_points: List[PTTargetPoint],
quantizer_config: QuantizerConfig,
parameters: FakeQuantizeParameters,
) -> List[PTSharedFnInsertionCommand]:
_, scale_shape, _ = PTMinMaxAlgoBackend._get_input_scale_shape(
nncf_graph, target_points[0], quantizer_config.per_channel
)

quantizer = PTMinMaxAlgoBackend._create_quantizer(
quantizer_config, scale_shape, parameters, target_points[0].target_type
)
return [create_shared_quantizer_insertion_command(target_points, quantizer)]

@staticmethod
def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[OperatorMetatype]:
types = []
19 changes: 18 additions & 1 deletion nncf/torch/graph/transformations/command_creation.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union
from typing import List, Union

from torch import Tensor

@@ -62,3 +62,20 @@ def create_quantizer_insertion_command(
compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
priority=TransformationPriority.QUANTIZATION_PRIORITY,
)


def create_shared_quantizer_insertion_command(
target_points: List[PTTargetPoint], quantizer: BaseQuantizer
) -> PTSharedFnInsertionCommand:
quantizers_ids = []
for target_point in target_points:
quantizers_ids.append(NonWeightQuantizerId(target_point.target_node_name, target_point.input_port_id))

storage_key = ";".join(str(quantizer_id) for quantizer_id in sorted(quantizers_ids, key=str))
return PTSharedFnInsertionCommand(
target_points=target_points,
fn=quantizer,
op_unique_name=storage_key,
compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
priority=TransformationPriority.QUANTIZATION_PRIORITY,
)
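A minimal usage sketch for the new helper. The `PTTargetPoint` constructor arguments follow its usual signature, the node names are hypothetical, and a stand-in module is passed in place of a real `BaseQuantizer` purely to show the call shape:

```python
import torch

from nncf.common.graph.transformations.commands import TargetType
from nncf.torch.graph.transformations.command_creation import create_shared_quantizer_insertion_command
from nncf.torch.graph.transformations.commands import PTTargetPoint

# Two activation target points that should share a single quantizer module.
points = [
    PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name="/model/concat_0"),
    PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name="/model/concat_1"),
]

shared_fn = torch.nn.Identity()  # stand-in for a BaseQuantizer instance
command = create_shared_quantizer_insertion_command(points, shared_fn)
# The command registers one external quantizer under a storage key built from the
# sorted per-point quantizer ids, so every target point calls the same module.
```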
1 change: 1 addition & 0 deletions nncf/torch/graph/transformations/commands.py
@@ -57,6 +57,7 @@ def __eq__(self, other: "PTTargetPoint"):
isinstance(other, PTTargetPoint)
and self.target_type == other.target_type
and self.target_node_name == other.target_node_name
and self.input_port_id == other.input_port_id
)

def __str__(self):
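A quick illustration of why the added `input_port_id` comparison matters: two pre-hook points on the same node but on different input ports are now distinct (constructor usage assumed as in the sketch above):

```python
from nncf.common.graph.transformations.commands import TargetType
from nncf.torch.graph.transformations.commands import PTTargetPoint

a = PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node_name="/model/add_0", input_port_id=0)
b = PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node_name="/model/add_0", input_port_id=1)

assert a != b  # previously these compared equal because __eq__ ignored input_port_id
assert a == PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node_name="/model/add_0", input_port_id=0)
```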