XNNPACKQuantizer
daniil-lyakhov committed Jan 9, 2025
1 parent cb3e426 commit d326e48
Showing 6 changed files with 550 additions and 468 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -19,6 +19,7 @@ repos:
    rev: v0.3.7
    hooks:
      - id: ruff
        args: [--fix]

  - repo: https://github.com/igorshubovych/markdownlint-cli
    rev: v0.33.0
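The new `args: [--fix]` makes the ruff hook apply safe autofixes to the checked files during the pre-commit run, rather than only reporting violations.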
119 changes: 119 additions & 0 deletions tests/torch/fx/performance_check/benchmark.py
@@ -0,0 +1,119 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import subprocess
from abc import ABC
from abc import abstractmethod
from enum import Enum
from pathlib import Path
from time import time
from typing import Any, List, Tuple

import openvino as ov
import torch
import torch.fx

from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
from tests.torch.fx.performance_check.model_scope import ModelConfig


class BenchmarkInterface(ABC):
    @abstractmethod
    def __call__(self, model: Any, model_config: ModelConfig, model_path: Path) -> Any:
        """
        Benchmarks the given model and returns the measured metric.
        """

    @abstractmethod
    def name(self) -> str:
        """
        Name of the benchmarking stage.
        """


class LatencyBenchmark(BenchmarkInterface):
    def __call__(self, model: Any, model_config: ModelConfig, model_path: Path) -> Any:
        with disable_patching():
            with torch.no_grad():
                example_inputs = model_config.model_builder.get_example_inputs()
                if isinstance(model, ov.Model):
                    return measure_time_ov(model, example_inputs, model_config.num_iters)
                return measure_time(model, example_inputs, model_config.num_iters)

    def name(self) -> str:
        return "Latency, msec"


class BenchmarkAppMode(Enum):
    SYNC = "sync"
    ASYNC = "async"


class BenchmarkAppFPS(BenchmarkInterface):
    def __init__(self, mode: BenchmarkAppMode) -> None:
        self.mode = mode

    def __call__(self, model: Any, model_config: ModelConfig, model_path: Path) -> Any:
        fps, latency = benchmark_performance(
            model_path=model_path,
            input_shape=model_config.model_builder.get_input_sizes(),
            mode=self.mode.value,
            num_iters=model_config.num_iters,
        )
        return fps, latency

    def name(self) -> str:
        return f"Benchmark app: {self.mode.value} (FPS, latency, msec)"


def measure_time(model, example_inputs, num_iters=500):
    with torch.no_grad():
        model(*example_inputs)  # Warmup inference
        total_time = 0
        for _ in range(num_iters):
            start_time = time()
            model(*example_inputs)
            total_time += time() - start_time
        average_time = (total_time / num_iters) * 1000
    return average_time


def measure_time_ov(model, example_inputs, num_iters=500):
    ie = ov.Core()
    compiled_model = ie.compile_model(model, "CPU")
    infer_request = compiled_model.create_infer_request()
    infer_request.infer(example_inputs)  # Warmup inference
    total_time = 0
    for _ in range(num_iters):
        start_time = time()
        infer_request.infer(example_inputs)
        total_time += time() - start_time
    average_time = (total_time / num_iters) * 1000
    return average_time


def benchmark_performance(model_path: str, input_shape: List[int], mode: str, num_iters: int) -> Tuple[float, float]:
    # "sync" maps to benchmark_app's latency hint, "async" to the throughput hint.
    if mode == "sync":
        exec_mode = "latency"
    else:
        exec_mode = "throughput"

    command = f"benchmark_app -m {model_path} -d CPU -hint {exec_mode} -niter {num_iters}"
    command += f' -shape "[{",".join(str(s) for s in input_shape)}]"'
    cmd_output = subprocess.check_output(command, shell=True)  # nosec

    # Parse throughput (FPS) and average latency (ms) out of the benchmark_app report.
    match = re.search(r"Throughput\: (.+?) FPS", str(cmd_output))
    fps = float(match.group(1))

    match = re.search(r"Average\: (.+?) ms", str(cmd_output))
    latency = float(match.group(1))
    return fps, latency
100 changes: 100 additions & 0 deletions tests/torch/fx/performance_check/export.py
@@ -0,0 +1,100 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import abstractmethod
from pathlib import Path
from typing import Any

import openvino as ov
import openvino.torch # noqa
import torch
import torch.fx
from torch._export import capture_pre_autograd_graph

from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
from tests.torch.fx.performance_check.model_scope import ModelConfig


class ExportInterface(ABC):
    @abstractmethod
    def __call__(self, model: Any, model_config: ModelConfig, path_to_save_model: Path) -> Any:
        """
        Converts the passed torch.nn.Module to the target representation.
        """

    @abstractmethod
    def name(self) -> str:
        """
        Name of the export stage that runs before quantization.
        """


class NoExport(ExportInterface):
    def __call__(self, model: Any, model_config: ModelConfig, path_to_save_model: Path) -> Any:
        return model

    def name(self) -> str:
        return "No export"


class CapturePreAutogradGraphExport(ExportInterface):
    def __call__(self, model: Any, model_config: ModelConfig, path_to_save_model: Path) -> torch.fx.GraphModule:
        with disable_patching():
            with torch.no_grad():
                return capture_pre_autograd_graph(model, args=model_config.model_builder.get_example_inputs())

    def name(self) -> str:
        return "capture_pre_autograd_graph"


class TorchExport(ExportInterface):
    def __call__(self, model: Any, model_config: ModelConfig, path_to_save_model: Path) -> Any:
        with disable_patching():
            with torch.no_grad():
                return torch.export.export(
                    model, args=model_config.model_builder.get_example_inputs(), strict=model_config.torch_export_strict
                ).module()

    def name(self) -> str:
        return "torch.export.export"


class OpenvinoIRExport(ExportInterface):
    def __call__(self, model: Any, model_config: ModelConfig, path_to_save_model: Path) -> Any:
        with disable_patching():
            with torch.no_grad():
                example_inputs = model_config.model_builder.get_example_inputs()
                export_inputs = example_inputs[0] if isinstance(example_inputs[0], tuple) else example_inputs
                input_sizes = model_config.model_builder.get_input_sizes()
                ex_model = torch.export.export(model, export_inputs)
                ov_model = ov.convert_model(ex_model, example_input=example_inputs[0], input=input_sizes)
                ov.serialize(ov_model, path_to_save_model)
                return ov_model

    def name(self) -> str:
        return "Export to OpenVINO IR"


class TorchCompileExport(ExportInterface):
    def __call__(self, model: Any, model_config: ModelConfig, path_to_save_model: Path):
        return torch.compile(model)

    def name(self) -> str:
        return "torch.compile(...)"


class TorchCompileOVExport(ExportInterface):
    def __call__(self, model: Any, model_config: ModelConfig, path_to_save_model: Path):
        return torch.compile(model, backend="openvino")

    def name(self) -> str:
        return "torch.compile(..., backend='openvino')"
(The remaining changed files in this commit are not shown.)