Skip to content

Commit

Permalink
Add check attr weight_port_ids
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Apr 20, 2024
1 parent 3a74dce commit 136756e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 2 additions & 3 deletions nncf/quantization/algorithms/min_max/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.hardware.config import PTHWConfig
from nncf.torch.model_graph_manager import get_weight_tensor_port_ids
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
from nncf.torch.quantization.layers import QUANTIZATION_MODULES
Expand Down Expand Up @@ -347,6 +348,4 @@ def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> List[str]:

@staticmethod
def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]:
return [
node for node in nncf_graph.get_all_nodes() if isinstance(node.layer_attributes, WeightedLayerAttributes)
]
return [node for node in nncf_graph.get_all_nodes() if get_weight_tensor_port_ids(node, nncf_graph)]
2 changes: 2 additions & 0 deletions nncf/torch/model_graph_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@ def get_weight_tensor_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[int]:
:param graph: The NNCF graph.
:return: List of ports with weights.
"""
if not hasattr(node, "metatype") or not hasattr(node.metatype, "weight_port_ids"):
return []
weight_port_ids = []
for edge in graph.get_input_edges(node):
if edge.input_port_id in node.metatype.weight_port_ids:
Expand Down

0 comments on commit 136756e

Please sign in to comment.