From 136756ee3f3fd9173f5f19a790e5c23d9bf1984c Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Sat, 20 Apr 2024 01:01:49 +0300 Subject: [PATCH] Add check attr weight_port_ids --- nncf/quantization/algorithms/min_max/torch_backend.py | 5 ++--- nncf/torch/model_graph_manager.py | 2 ++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py index a1943d893e7..e325c6c7f70 100644 --- a/nncf/quantization/algorithms/min_max/torch_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_backend.py @@ -41,6 +41,7 @@ from nncf.torch.graph.transformations.commands import PTInsertionCommand from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.hardware.config import PTHWConfig +from nncf.torch.model_graph_manager import get_weight_tensor_port_ids from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT from nncf.torch.quantization.layers import QUANTIZATION_MODULES @@ -347,6 +348,4 @@ def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> List[str]: @staticmethod def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]: - return [ - node for node in nncf_graph.get_all_nodes() if isinstance(node.layer_attributes, WeightedLayerAttributes) - ] + return [node for node in nncf_graph.get_all_nodes() if get_weight_tensor_port_ids(node, nncf_graph)] diff --git a/nncf/torch/model_graph_manager.py b/nncf/torch/model_graph_manager.py index 0d7120ba07b..0d40f655f3c 100644 --- a/nncf/torch/model_graph_manager.py +++ b/nncf/torch/model_graph_manager.py @@ -244,6 +244,8 @@ def get_weight_tensor_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[int]: :param graph: The NNCF graph. :return: List of ports with weights. """ + if not hasattr(node, "metatype") or not hasattr(node.metatype, "weight_port_ids"): + return [] weight_port_ids = [] for edge in graph.get_input_edges(node): if edge.input_port_id in node.metatype.weight_port_ids: