Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add int4 and int8 checks #2617

Merged
merged 5 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion tests/post_training/data/wc_reference_data.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
tinyllama_data_free_backend_OV:
metric_value: 0.72057
num_int4: 228
num_int8: 84
tinyllama_data_aware_backend_OV:
metric_value: 0.83084
DaniAffCH marked this conversation as resolved.
Show resolved Hide resolved
num_int4: 184
num_int8: 128
tinyllama_data_aware_awq_backend_OV:
metric_value: 0.81237
num_int4: 184
num_int8: 128
tinyllama_data_aware_awq_stateful_backend_OV:
metric_value: 0.81237
metric_value: 0.81237
num_int4: 184
num_int8: 128
17 changes: 13 additions & 4 deletions tests/post_training/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ def fill(self, stdout: str) -> None:
"""


@dataclass
class NumCompressNodes:
    """
    Counters describing how many compressed/quantized nodes a run produced.

    Every field defaults to ``None``, meaning "not collected for this run".
    """

    # Number of nodes whose type is "FakeQuantize".
    num_fq_nodes: Optional[int] = None
    # Number of node outputs whose element type is i8 or u8.
    num_int8: Optional[int] = None
    # Number of node outputs whose element type is i4 or u4.
    num_int4: Optional[int] = None


@dataclass
class PTQTimeStats(StatsFromOutput):
"""
Expand Down Expand Up @@ -130,12 +137,12 @@ class RunInfo:
metric_name: Optional[str] = None
metric_value: Optional[float] = None
metric_diff: Optional[float] = None
num_fq_nodes: Optional[float] = None
compression_memory_usage: Optional[int] = None
status: Optional[str] = None
fps: Optional[float] = None
time_total: Optional[float] = None
time_compression: Optional[float] = None
num_compress_nodes: Optional[NumCompressNodes] = None
stats_from_output = StatsFromOutput()

@staticmethod
Expand All @@ -157,7 +164,9 @@ def get_result_dict(self):
"Metric name": self.metric_name,
"Metric value": self.metric_value,
"Metric diff": self.metric_diff,
"Num FQ": self.num_fq_nodes,
"Num FQ": self.num_compress_nodes.num_fq_nodes,
"Num int4": self.num_compress_nodes.num_int4,
"Num int8": self.num_compress_nodes.num_int8,
"RAM MiB": self.format_memory_usage(self.compression_memory_usage),
"Compr. time": self.format_time(self.time_compression),
**self.stats_from_output.get_stats(),
Expand Down Expand Up @@ -210,7 +219,7 @@ def __init__(
self.dummy_tensor = None
self.input_size = None

self.run_info = RunInfo(model=reported_name, backend=self.backend)
self.run_info = RunInfo(model=reported_name, backend=self.backend, num_compress_nodes=NumCompressNodes())

@abstractmethod
def prepare_preprocessor(self) -> None:
Expand Down Expand Up @@ -383,7 +392,7 @@ def get_num_compressed(self) -> None:
if node_type == "FakeQuantize":
num_fq += 1

self.run_info.num_fq_nodes = num_fq
self.run_info.num_compress_nodes.num_fq_nodes = num_fq

def run_bench(self) -> None:
"""
Expand Down
32 changes: 31 additions & 1 deletion tests/post_training/pipelines/lm_weight_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,21 @@ def save_compressed_model(self) -> None:
self.model_hf._save_config(self.output_model_dir)

def get_num_compressed(self) -> None:
    """
    Count the int8 and int4 outputs in the compressed IR.

    Walks every op returned by ``self.model.get_ops()`` and inspects the element
    type name of each of its outputs: ``i8``/``u8`` count towards int8 and
    ``i4``/``u4`` count towards int4. The totals are stored in
    ``self.run_info.num_compress_nodes.num_int8`` and ``.num_int4``.
    """
    num_int8 = 0
    num_int4 = 0

    for node in self.model.get_ops():
        for i in range(node.get_output_size()):
            # Resolve the type name once per output; the two groups are disjoint.
            type_name = node.get_output_element_type(i).get_type_name()
            if type_name in ("i8", "u8"):
                num_int8 += 1
            elif type_name in ("i4", "u4"):
                num_int4 += 1

    self.run_info.num_compress_nodes.num_int8 = num_int8
    self.run_info.num_compress_nodes.num_int4 = num_int4

def run_bench(self) -> None:
    """Do nothing: this implementation of the benchmark step is a no-op."""
Expand Down Expand Up @@ -219,3 +233,19 @@ def _validate(self):
similarity = all_metrics["similarity"][0]
self.run_info.metric_name = "Similarity"
self.run_info.metric_value = round(similarity, 5)

num_int4_reference = self.reference_data.get("num_int4")
num_int8_reference = self.reference_data.get("num_int8")

num_int4_value = self.run_info.num_compress_nodes.num_int4
num_int8_value = self.run_info.num_compress_nodes.num_int8

if num_int4_reference != num_int4_value:
status_msg = f"Regression: The number of int4 ops is different \
than reference {num_int4_reference} != {num_int4_value}"
raise ValueError(status_msg)

if num_int8_reference != num_int8_value:
status_msg = f"Regression: The number of int8 ops is different \
than reference {num_int8_reference} != {num_int8_value}"
raise ValueError(status_msg)
Loading