diff --git a/examples/quantization_aware_training/torch/anomalib/main.py b/examples/quantization_aware_training/torch/anomalib/main.py
index f21dbdd642c..20fb234a90d 100644
--- a/examples/quantization_aware_training/torch/anomalib/main.py
+++ b/examples/quantization_aware_training/torch/anomalib/main.py
@@ -12,13 +12,17 @@
 import os
 import re
 import subprocess
+import tarfile
 from copy import deepcopy
 from pathlib import Path
 from typing import List
+from urllib.request import urlretrieve
 
 import torch
 from anomalib import TaskType
 from anomalib.data import MVTec
+from anomalib.data.image import mvtec
+from anomalib.data.utils import download
 from anomalib.deploy import ExportType
 from anomalib.engine import Engine
 from anomalib.models import Stfpm
@@ -27,14 +31,64 @@
 
 HOME_PATH = Path.home()
 DATASET_PATH = HOME_PATH / ".cache/nncf/datasets/mvtec"
-CHECKPOINT_PATH = HOME_PATH / ".cache/nncf/models/stfpm_mvtec"
+CHECKPOINT_PATH = HOME_PATH / ".cache/nncf/models/anomalib"
 ROOT = Path(__file__).parent.resolve()
 FP32_RESULTS_ROOT = ROOT / "fp32"
 INT8_RESULTS_ROOT = ROOT / "int8"
-CHECKPOINT_URL = "https://huggingface.co/alexsu52/stfpm_mvtec_capsule/resolve/main/qat/model.ckpt"
+CHECKPOINT_URL = "https://storage.openvinotoolkit.org/repositories/nncf/examples/torch/anomalib/stfpm_mvtec.ckpt"
 USE_PRETRAINED = True
 
 
+def download_and_extract(root: Path, info: download.DownloadInfo) -> None:
+    """
+    Download the archive described by `info` into `root` and extract it there.
+
+    The downloaded archive file is removed afterwards, including when the download,
+    the hash check or the extraction raises, so no stale archive is left in `root`.
+
+    :param root: Directory that receives the archive and its extracted content.
+        It is created, together with missing parents, if it does not exist.
+    :param info: Download description providing `name`, `url` and `hashsum`.
+    """
+    root.mkdir(parents=True, exist_ok=True)
+    downloaded_file_path = root / info.url.split("/")[-1]
+    try:
+        print(f"Downloading the {info.name} dataset.")
+        with download.DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=info.name) as progress_bar:
+            urlretrieve(
+                url=info.url,
+                filename=downloaded_file_path,
+                reporthook=progress_bar.update_to,
+            )
+        print("Checking the hash of the downloaded file.")
+        download.check_hash(downloaded_file_path, info.hashsum)
+        print(f"Extracting the {info.name} dataset.")
+        # NOTE(review): extractall() does no member filtering here; this relies on
+        # the hash check above having validated the archive - TODO confirm.
+        with tarfile.open(downloaded_file_path) as tar_file:
+            tar_file.extractall(root)
+    finally:
+        # Remove the archive on every exit path, not only after a successful extraction.
+        print("Cleaning up files.")
+        if downloaded_file_path.exists():
+            downloaded_file_path.unlink()
+
+
+def create_dataset(root: Path) -> MVTec:
+    """
+    Create the MVTec datamodule for `root`, downloading the dataset first if needed.
+
+    The download runs when `root` is missing or is an empty directory, so an empty
+    directory left behind by an earlier failed download does not block a retry.
+
+    :param root: Directory that holds, or will hold, the dataset.
+    :return: The `MVTec` instance constructed from `root`.
+    """
+    if not root.is_dir() or not any(root.iterdir()):
+        download_and_extract(root, mvtec.DOWNLOAD_INFO)
+    return MVTec(root)
+
+
 def run_benchmark(model_path: Path, shape: List[int]) -> float:
     command = f"benchmark_app -m {model_path} -d CPU -api async -t 15"
     command += f' -shape "[{",".join(str(x) for x in shape)}]"'
@@ -65,14 +119,14 @@ def main():
     print(os.linesep + "[Step 1] Prepare the model and dataset")
 
     model = Stfpm()
-    datamodule = MVTec(root=DATASET_PATH)
+    datamodule = create_dataset(root=DATASET_PATH)
 
     # Create an engine for the original model
     engine = Engine(task=TaskType.SEGMENTATION, default_root_dir=FP32_RESULTS_ROOT)
     if USE_PRETRAINED:
         # Load the pretrained checkpoint
         CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)
-        ckpt_path = CHECKPOINT_PATH / "model.ckpt"
+        ckpt_path = CHECKPOINT_PATH / "stfpm_mvtec.ckpt"
         torch.hub.download_url_to_file(CHECKPOINT_URL, ckpt_path)
     else:
         # (Optional) Train the model from scratch