Skip to content

Commit

Permalink
Copy model/nodes meta | comments
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jan 17, 2025
1 parent 88d27d7 commit 95e4bc7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
3 changes: 2 additions & 1 deletion nncf/experimental/quantization/quantizers/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ class Quantizer(ABC):
"""

@abstractmethod
def transform_prior_quantization(self, model: TModel):
def transform_prior_quantization(self, model: TModel) -> TModel:
"""
Transforms the given model in-place with the necessary modifications required prior to quantization.
:param model: Backend-specific model to be transformed.
:return: Transformed backend-specific model.
"""

@abstractmethod
Expand Down
25 changes: 14 additions & 11 deletions nncf/experimental/quantization/quantizers/torch_ao_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,26 @@ class TorchAOQuantizerAdapter(Quantizer):
def __init__(self, quantizer: TorchAOQuantizer):
self._quantizer = quantizer

def transform_prior_quantization(self, model: torch.fx.GraphModule):
self._quantizer.transform_for_annotation(model)
def transform_prior_quantization(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
return self._quantizer.transform_for_annotation(model)

def get_quantization_setup(self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup:
original_meta = model.meta

self._quantizer.annotate(model)
# Save model and nodes meta before the annotation
original_meta = model.meta.copy()
node_name_vs_meta = {}
with torch.no_grad():
for node in model.graph.nodes:
node_name_vs_meta[node.name] = node.meta.copy()

model = self._quantizer.annotate(model)
self._quantizer.validate(model)
quantizer_setup = self.get_quantizer_config_from_anotated_model(model)

# Remove quantization annotations from the original model
quantization_annotation_key = "quantization_annotation"
for n in model.graph.nodes:
if hasattr(n, "meta") and quantization_annotation_key in n.meta:
del n.meta[quantization_annotation_key]

# Recover original meta
model.meta = original_meta
for node in model.graph.nodes:
node.meta = node_name_vs_meta[node.name]

return quantizer_setup

@staticmethod
Expand Down

0 comments on commit 95e4bc7

Please sign in to comment.