Skip to content

Commit

Permalink
WIP resnet18 accuracy check
Browse files — browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed May 29, 2024
1 parent 857a255 commit 8934a06
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 7 deletions.
39 changes: 34 additions & 5 deletions examples/quantization_aware_training/torch/resnet18/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
# limitations under the License.

import os

os.environ["TORCHINDUCTOR_FREEZING"] = "1"


import re
import subprocess
import warnings
Expand All @@ -19,6 +23,7 @@

import openvino as ov
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
import torch.nn as nn
import torch.nn.parallel
import torch.optim
Expand All @@ -28,6 +33,10 @@
import torchvision.models as models
import torchvision.transforms as transforms
from fastdownload import FastDownload
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.jit import TracerWarning

import nncf
Expand Down Expand Up @@ -102,7 +111,7 @@ def validate(val_loader: torch.utils.data.DataLoader, model: torch.nn.Module, de
top1_sum = 0.0

# Switch to evaluate mode.
model.eval()
# model.eval()

with torch.no_grad():
for images, target in track(val_loader, total=len(val_loader), description="Validation:"):
Expand Down Expand Up @@ -230,7 +239,7 @@ def get_model_size(ir_path: str, m_type: str = "Mb") -> float:

def main():
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
print(f"Using {device} device")

###############################################################################
Expand All @@ -253,11 +262,31 @@ def transform_fn(data_item):
# Step 2: Quantize model
print(os.linesep + "[Step 2] Quantize model")

quantized_model = nncf.quantize(model, quantization_dataset)
acc1_int8_init = validate(val_loader, quantized_model, device)
with torch.no_grad():
example_inputs = (torch.ones((1, 3, IMAGE_SIZE, IMAGE_SIZE)),)
exported_model = capture_pre_autograd_graph(model, example_inputs)

print(f"Accuracy@1 of initialized INT8 model: {acc1_int8_init:.3f}")
NNCF_TORCH_FX = False

if NNCF_TORCH_FX:
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())

prepared_model = prepare_pt2e(exported_model, quantizer)
from itertools import islice

from tqdm import tqdm

for data in tqdm(islice(quantization_dataset.get_inference_data(), 3)):
prepared_model(data)
quantized_model = convert_pt2e(prepared_model)
else:
quantized_model = nncf.quantize(exported_model, quantization_dataset)

quantized_model = torch.compile(quantized_model)
acc1_int8_init = validate(val_loader, quantized_model, device)
print(f"Accuracy@1 of initialized INT8 model: {acc1_int8_init:.3f}")
return
###############################################################################
# Step 3: Fine tune quantized model
print(os.linesep + "[Step 3] Fine tune quantized model")
Expand Down
1 change: 1 addition & 0 deletions nncf/experimental/torch_fx/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(self, model: torch.fx.GraphModule):
super().__init__(model)

self._command_transformation_ordered_pairs = [
# TODO: Move the module insertion command to a transformation
(FXApplyTransformationCommand, self._apply_transformation),
(FXModuleInsertionCommand, self._apply_module_insertion),
]
Expand Down
1 change: 1 addition & 0 deletions nncf/experimental/torch_fx/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def quantize_impl(
raise ValueError(f"mode={mode} is not supported")

copied_model = deepcopy(model)
# copied_model = model

quantization_algorithm = PostTrainingQuantization(
preset=preset,
Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/torch_fx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand
from nncf.experimental.torch_fx.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
from nncf.experimental.torch_fx.transformations import fake_quantize_insertion_tranformation_builder
from nncf.parameters import ModelType
from nncf.parameters import TargetDevice
Expand All @@ -41,7 +42,6 @@
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.hardware.config import PTHWConfig
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
from nncf.torch.quantization.layers import QUANTIZATION_MODULES
from nncf.torch.quantization.layers import AsymmetricQuantizer
from nncf.torch.quantization.layers import BaseQuantizer
Expand Down
2 changes: 1 addition & 1 deletion nncf/torch/graph/pattern_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
)

ARITHMETIC_OPERATIONS = {
GraphPattern.METATYPE_ATTR: ["__iadd__", "__add__", "__mul__", "__rmul__", "__truediv__"],
GraphPattern.METATYPE_ATTR: ["__iadd__", "__add__", "__mul__", "__rmul__", "__truediv__", "add_"],
GraphPattern.LABEL_ATTR: "ARITHMETIC",
}

Expand Down

0 comments on commit 8934a06

Please sign in to comment.