diff --git a/examples/quantization_aware_training/torch/resnet18/main.py b/examples/quantization_aware_training/torch/resnet18/main.py
index 7ab3b7af14a..83c07e71438 100644
--- a/examples/quantization_aware_training/torch/resnet18/main.py
+++ b/examples/quantization_aware_training/torch/resnet18/main.py
@@ -10,6 +10,10 @@
 # limitations under the License.
 
 import os
+
+# Inductor freezing is required so that quantization parameters are constant-folded at compile time.
+os.environ["TORCHINDUCTOR_FREEZING"] = "1"
+
 import re
 import subprocess
 import warnings
@@ -19,6 +23,7 @@
 
 import openvino as ov
 import torch
+import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
 import torch.nn as nn
 import torch.nn.parallel
 import torch.optim
@@ -28,6 +33,10 @@
 import torchvision.models as models
 import torchvision.transforms as transforms
 from fastdownload import FastDownload
+from torch._export import capture_pre_autograd_graph
+from torch.ao.quantization.quantize_pt2e import convert_pt2e
+from torch.ao.quantization.quantize_pt2e import prepare_pt2e
+from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
 from torch.jit import TracerWarning
 
 import nncf
@@ -102,7 +111,7 @@ def validate(val_loader: torch.utils.data.DataLoader, model: torch.nn.Module, de
     top1_sum = 0.0
 
     # Switch to evaluate mode.
-    model.eval()
+    # model.eval()  # eval() is not supported on graphs captured by capture_pre_autograd_graph
 
     with torch.no_grad():
         for images, target in track(val_loader, total=len(val_loader), description="Validation:"):
@@ -230,7 +239,7 @@ def get_model_size(ir_path: str, m_type: str = "Mb") -> float:
 
 def main():
     torch.manual_seed(0)
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cpu")  # the x86 Inductor quantization flow targets CPU
     print(f"Using {device} device")
 
     ###############################################################################
@@ -253,11 +262,32 @@ def transform_fn(data_item):
 
     # Step 2: Quantize model
     print(os.linesep + "[Step 2] Quantize model")
-    quantized_model = nncf.quantize(model, quantization_dataset)
-    acc1_int8_init = validate(val_loader, quantized_model, device)
+    with torch.no_grad():
+        example_inputs = (torch.ones((1, 3, IMAGE_SIZE, IMAGE_SIZE)),)
+        exported_model = capture_pre_autograd_graph(model, example_inputs)
 
-    print(f"Accuracy@1 of initialized INT8 model: {acc1_int8_init:.3f}")
+    USE_X86_INDUCTOR_QUANTIZER = False
+    if USE_X86_INDUCTOR_QUANTIZER:
+        quantizer = X86InductorQuantizer()
+        quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
+
+        prepared_model = prepare_pt2e(exported_model, quantizer)
+        from itertools import islice
+
+        from tqdm import tqdm
+
+        # Calibrate on a few batches to collect activation statistics.
+        for data in tqdm(islice(quantization_dataset.get_inference_data(), 3)):
+            prepared_model(data)
+        quantized_model = convert_pt2e(prepared_model)
+    else:
+        quantized_model = nncf.quantize(exported_model, quantization_dataset)
+
+    quantized_model = torch.compile(quantized_model)
+    acc1_int8_init = validate(val_loader, quantized_model, device)
+    print(f"Accuracy@1 of initialized INT8 model: {acc1_int8_init:.3f}")
+    return  # stop after PTQ; the fine-tuning and export steps below are skipped for now
 
     ###############################################################################
     # Step 3: Fine tune quantized model
     print(os.linesep + "[Step 3] Fine tune quantized model")
diff --git a/nncf/experimental/torch_fx/model_transformer.py b/nncf/experimental/torch_fx/model_transformer.py
index 0a671dcffe0..a3ac2caacb8 100644
--- a/nncf/experimental/torch_fx/model_transformer.py
+++ b/nncf/experimental/torch_fx/model_transformer.py
@@ -84,6 +84,7 @@ def __init__(self, model: torch.fx.GraphModule):
         super().__init__(model)
 
         self._command_transformation_ordered_pairs = [
+            # TODO: Move the module insertion command to a transformation
             (FXApplyTransformationCommand, self._apply_transformation),
             (FXModuleInsertionCommand, self._apply_module_insertion),
         ]
diff --git a/nncf/experimental/torch_fx/quantization/quantize_model.py b/nncf/experimental/torch_fx/quantization/quantize_model.py
index a9d980453b3..3a6302c12c3 100644
--- a/nncf/experimental/torch_fx/quantization/quantize_model.py
+++ b/nncf/experimental/torch_fx/quantization/quantize_model.py
@@ -55,6 +55,7 @@ def quantize_impl(
         raise ValueError(f"mode={mode} is not supported")
 
     copied_model = deepcopy(model)
+    # copied_model = model
 
     quantization_algorithm = PostTrainingQuantization(
         preset=preset,
diff --git a/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/nncf/quantization/algorithms/min_max/torch_fx_backend.py
index 2a363c8ee75..5e1bcd5692c 100644
--- a/nncf/quantization/algorithms/min_max/torch_fx_backend.py
+++ b/nncf/quantization/algorithms/min_max/torch_fx_backend.py
@@ -27,6 +27,7 @@
 from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand
+from nncf.experimental.torch_fx.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
 from nncf.experimental.torch_fx.transformations import fake_quantize_insertion_tranformation_builder
 from nncf.parameters import ModelType
 from nncf.parameters import TargetDevice
@@ -41,7 +42,6 @@
 from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
 from nncf.torch.hardware.config import PTHWConfig
 from nncf.torch.nncf_network import NNCFNetwork
-from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
 from nncf.torch.quantization.layers import QUANTIZATION_MODULES
 from nncf.torch.quantization.layers import AsymmetricQuantizer
 from nncf.torch.quantization.layers import BaseQuantizer
diff --git a/nncf/torch/graph/pattern_operations.py b/nncf/torch/graph/pattern_operations.py
index 85e4bb0f6a9..b860560221d 100644
--- a/nncf/torch/graph/pattern_operations.py
+++ b/nncf/torch/graph/pattern_operations.py
@@ -72,7 +72,7 @@
 )
 
 ARITHMETIC_OPERATIONS = {
-    GraphPattern.METATYPE_ATTR: ["__iadd__", "__add__", "__mul__", "__rmul__", "__truediv__"],
+    GraphPattern.METATYPE_ATTR: ["__iadd__", "__add__", "__mul__", "__rmul__", "__truediv__", "add_"],
     GraphPattern.LABEL_ATTR: "ARITHMETIC",
 }
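
Note (not part of the patch): end-to-end, the flow this diff wires into the example reduces to the standalone sketch below. The model, weights, calibration data, and input shape are illustrative stand-ins; it assumes a PyTorch build that provides torch._export.capture_pre_autograd_graph and an NNCF build that includes the experimental torch.fx quantize path added here.

import os

# Mirrors the example: set before importing torch so Inductor picks it up.
os.environ["TORCHINDUCTOR_FREEZING"] = "1"

import torch
import torchvision.models as models
from torch._export import capture_pre_autograd_graph

import nncf

# Illustrative model and calibration data (names and sizes are placeholders).
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).eval()
calibration_data = [torch.randn(1, 3, 224, 224) for _ in range(10)]
quantization_dataset = nncf.Dataset(calibration_data)

# Capture the model into a torch.fx graph ahead of quantization.
example_inputs = (torch.ones((1, 3, 224, 224)),)
with torch.no_grad():
    exported_model = capture_pre_autograd_graph(model, example_inputs)

# Post-training INT8 quantization of the captured graph (the path this patch adds).
quantized_model = nncf.quantize(exported_model, quantization_dataset)

# Compile with Inductor; freezing constant-folds the quantization parameters.
compiled_model = torch.compile(quantized_model)
with torch.no_grad():
    output = compiled_model(example_inputs[0])

Setting TORCHINDUCTOR_FREEZING before the first torch import follows the example's own pattern, since Inductor reads the variable when its config is first loaded.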