Skip to content

Commit

Permalink
fixed downloading and extraction of mvtec dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
alexsu52 committed Apr 17, 2024
1 parent 4000422 commit 4522cbd
Showing 1 changed file with 33 additions and 4 deletions.
37 changes: 33 additions & 4 deletions examples/quantization_aware_training/torch/anomalib/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@
import os
import re
import subprocess
import tarfile
from copy import deepcopy
from pathlib import Path
from typing import List
from urllib.request import urlretrieve

import torch
from anomalib import TaskType
from anomalib.data import MVTec
from anomalib.data.image import mvtec
from anomalib.data.utils import download
from anomalib.deploy import ExportType
from anomalib.engine import Engine
from anomalib.models import Stfpm
Expand All @@ -27,14 +31,39 @@

# Cache locations live under the current user's home directory.
HOME_PATH = Path.home()
# Dataset root; main() passes this to the dataset factory.
DATASET_PATH = HOME_PATH / ".cache/nncf/datasets/mvtec"
# NOTE(review): CHECKPOINT_PATH is assigned twice in a row; as written, the
# second value overrides the first. This looks like the old and new lines of
# a diff shown side by side -- confirm which one should remain.
CHECKPOINT_PATH = HOME_PATH / ".cache/nncf/models/stfpm_mvtec"
CHECKPOINT_PATH = HOME_PATH / ".cache/nncf/models/anomalib"
# Directory containing this script; the result roots below are relative to it.
ROOT = Path(__file__).parent.resolve()
FP32_RESULTS_ROOT = ROOT / "fp32"
INT8_RESULTS_ROOT = ROOT / "int8"
# NOTE(review): CHECKPOINT_URL is also assigned twice; the second value wins.
CHECKPOINT_URL = "https://huggingface.co/alexsu52/stfpm_mvtec_capsule/resolve/main/qat/model.ckpt"
CHECKPOINT_URL = "https://storage.openvinotoolkit.org/repositories/nncf/examples/torch/anomalib/stfpm_mvtec.ckpt"
# When True, main() downloads a pretrained checkpoint from CHECKPOINT_URL
# instead of taking the train-from-scratch branch.
USE_PRETRAINED = True


def download_and_extract(root: Path, info: download.DownloadInfo) -> None:
    """Download the archive described by ``info`` into ``root`` and unpack it there.

    The archive is saved under ``root`` using the last path component of
    ``info.url``, verified with ``download.check_hash`` against
    ``info.hashsum``, extracted into ``root`` and finally deleted. The
    downloaded archive is removed even when the download, the hash check or
    the extraction fails, so a broken file is never left behind.

    Args:
        root: Directory that receives the extracted dataset; created if missing.
        info: Download description providing ``name``, ``url`` and ``hashsum``.

    Raises:
        RuntimeError: If an archive member would be extracted outside ``root``
            (only on interpreters without tarfile extraction filters).
    """
    root.mkdir(parents=True, exist_ok=True)
    downloaded_file_path = root / info.url.split("/")[-1]
    try:
        print(f"Downloading the {info.name} dataset.")
        with download.DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=info.name) as progress_bar:
            urlretrieve(
                url=info.url,
                filename=downloaded_file_path,
                reporthook=progress_bar.update_to,
            )
        print("Checking the hash of the downloaded file.")
        download.check_hash(downloaded_file_path, info.hashsum)
        print(f"Extracting the {info.name} dataset.")
        with tarfile.open(downloaded_file_path) as tar_file:
            if hasattr(tarfile, "data_filter"):
                # Extraction filters reject absolute paths, ".." traversal and
                # links that point outside of the destination directory.
                tar_file.extractall(root, filter="data")
            else:
                # Older interpreters: best-effort check that every member
                # path stays inside ``root`` before extracting anything.
                resolved_root = root.resolve()
                for member in tar_file.getmembers():
                    target = (resolved_root / member.name).resolve()
                    if target != resolved_root and resolved_root not in target.parents:
                        raise RuntimeError(f"Unsafe path in the archive: {member.name}")
                tar_file.extractall(root)
    finally:
        print("Cleaning up files.")
        # missing_ok: the archive may not exist if the download itself failed.
        downloaded_file_path.unlink(missing_ok=True)


def create_dataset(root: Path) -> MVTec:
    """Return an ``MVTec`` datamodule rooted at ``root``, downloading the data if needed.

    The dataset is downloaded and extracted when ``root`` does not exist or
    is an empty directory. An empty directory is treated as "not downloaded"
    because it can be left behind when an earlier download attempt failed
    right after the directory was created.

    Args:
        root: Directory that holds (or will hold) the extracted dataset.

    Returns:
        The ``MVTec`` object constructed from ``root``.
    """
    dataset_missing = not root.exists() or (root.is_dir() and not any(root.iterdir()))
    if dataset_missing:
        download_and_extract(root, mvtec.DOWNLOAD_INFO)
    return MVTec(root)


def run_benchmark(model_path: Path, shape: List[int]) -> float:
command = f"benchmark_app -m {model_path} -d CPU -api async -t 15"
command += f' -shape "[{",".join(str(x) for x in shape)}]"'
Expand Down Expand Up @@ -65,14 +94,14 @@ def main():
print(os.linesep + "[Step 1] Prepare the model and dataset")

model = Stfpm()
datamodule = MVTec(root=DATASET_PATH)
datamodule = create_dataset(root=DATASET_PATH)

# Create an engine for the original model
engine = Engine(task=TaskType.SEGMENTATION, default_root_dir=FP32_RESULTS_ROOT)
if USE_PRETRAINED:
# Load the pretrained checkpoint
CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)
ckpt_path = CHECKPOINT_PATH / "model.ckpt"
ckpt_path = CHECKPOINT_PATH / "stfpm_mvtec.ckpt"
torch.hub.download_url_to_file(CHECKPOINT_URL, ckpt_path)
else:
# (Optional) Train the model from scratch
Expand Down

0 comments on commit 4522cbd

Please sign in to comment.