Skip to content

Commit

Permalink
Comments
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jun 28, 2023
1 parent 40ce658 commit bd3660d
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 14 deletions.
14 changes: 4 additions & 10 deletions nncf/common/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,22 +711,16 @@ def remove_nodes_from(self, nodes: List[NNCFNode]) -> None:
for node_key, node in self._nx_graph.nodes.items():
self._node_id_to_key_dict[node["id"]] = node_key

def find_matching_nodes(self, patterns: GraphPattern) -> List[NNCFNode]:
def find_matching_subgraphs(self, patterns: GraphPattern) -> List[List[NNCFNode]]:
"""
Returns nodes of matched pattern in patterns.
Returns subgraphs of matched pattern in patterns.
:param patterns: Instance of GraphPattern containing all patterns.
:return: Nodes that are matched patterns.
:return: List of subgraphs that are matching by pattern matching.
Subgraph is a ordered list of nodes of matched subgraph
The returned nodes order relies on DiGraphMatcher isomorphic subgraphs matching logic from networkX package.
DiGraphMatcher does not guarantee a specific order for returning isomorphic subgraphs.
"""
output = []
for matched_subgraph in find_subgraphs_matching_pattern(self._nx_graph, patterns):
for node_key in matched_subgraph:
output.append(self.get_node_by_key(node_key))
return output

def find_matching_subgraphs(self, patterns: GraphPattern) -> List[List[NNCFNode]]:
output = []
for matched_subgraph in find_subgraphs_matching_pattern(self._nx_graph, patterns):
subgraph_list = []
Expand Down
8 changes: 5 additions & 3 deletions nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,11 @@ def _get_ignored_scope(self, inference_nncf_graph: NNCFGraph, ignored_patterns:
:param ignored_patterns: Ignored patterns.
:return: IgnoredScope with all node names matched ignored_patterns.
"""
nncf_node_names = [
nncf_node.node_name for nncf_node in inference_nncf_graph.find_matching_nodes(ignored_patterns)
]
nncf_node_names = []
for subgraph in inference_nncf_graph.find_matching_subgraphs(ignored_patterns):
for nncf_node in subgraph:
nncf_node_names.append(nncf_node.node_name)

return IgnoredScope(names=nncf_node_names)

def _get_quantizer_setup(
Expand Down
3 changes: 2 additions & 1 deletion tests/common/graph/test_graph_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,11 @@ def test_matches_with_any_pattern_node_type():

def test_not_match_edges_inside_pattern():
ref_graph = nx.DiGraph()
ref_graph.add_node("0")
ref_graph.add_node("0", **{GraphPattern.METATYPE_ATTR: "0"})
ref_graph.add_node("1", **{GraphPattern.METATYPE_ATTR: "a"})
ref_graph.add_node("2", **{GraphPattern.METATYPE_ATTR: "b"})
ref_graph.add_node("3", **{GraphPattern.METATYPE_ATTR: "c"})
ref_graph.add_edge("0", "1")
ref_graph.add_edge("1", "2")
ref_graph.add_edge("2", "3")
ref_graph.add_edge("1", "3")
Expand Down

0 comments on commit bd3660d

Please sign in to comment.