diff --git a/nncf/experimental/quantization/quantizers/quantizer.py b/nncf/experimental/quantization/quantizers/quantizer.py
index 838b157123a..7307726756a 100644
--- a/nncf/experimental/quantization/quantizers/quantizer.py
+++ b/nncf/experimental/quantization/quantizers/quantizer.py
@@ -26,11 +26,12 @@ class Quantizer(ABC):
     """

     @abstractmethod
-    def transform_prior_quantization(self, model: TModel):
+    def transform_prior_quantization(self, model: TModel) -> TModel:
         """
         Transforms the given model in-place with the necessary modifications required prior to quantization.

         :param model: Backend-specific model to be transformed.
+        :return: Transformed backend-specific model.
         """

     @abstractmethod
diff --git a/nncf/experimental/quantization/quantizers/torch_ao_adapter.py b/nncf/experimental/quantization/quantizers/torch_ao_adapter.py
index f148cf36ef6..2823c1afc50 100644
--- a/nncf/experimental/quantization/quantizers/torch_ao_adapter.py
+++ b/nncf/experimental/quantization/quantizers/torch_ao_adapter.py
@@ -44,23 +44,26 @@ class TorchAOQuantizerAdapter(Quantizer):
     def __init__(self, quantizer: TorchAOQuantizer):
         self._quantizer = quantizer

-    def transform_prior_quantization(self, model: torch.fx.GraphModule):
-        self._quantizer.transform_for_annotation(model)
+    def transform_prior_quantization(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        return self._quantizer.transform_for_annotation(model)

     def get_quantization_setup(self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup:
-        original_meta = model.meta
-
-        self._quantizer.annotate(model)
+        # Save model and nodes meta before the annotation
+        original_meta = model.meta.copy()
+        node_name_vs_meta = {}
+        with torch.no_grad():
+            for node in model.graph.nodes:
+                node_name_vs_meta[node.name] = node.meta.copy()
+
+        model = self._quantizer.annotate(model)
         self._quantizer.validate(model)
         quantizer_setup = self.get_quantizer_config_from_anotated_model(model)

-        # Remove quantization annotations from the original model
-        quantization_annotation_key = "quantization_annotation"
-        for n in model.graph.nodes:
-            if hasattr(n, "meta") and quantization_annotation_key in n.meta:
-                del n.meta[quantization_annotation_key]
-
+        # Recover original meta
         model.meta = original_meta
+        for node in model.graph.nodes:
+            node.meta = node_name_vs_meta[node.name]
+
         return quantizer_setup

     @staticmethod