Skip to content

Commit

Permalink
Revert graph matching refactoring as node order bug is present
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jun 28, 2023
1 parent dadfae0 commit e5e0b31
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 40 deletions.
78 changes: 41 additions & 37 deletions nncf/common/graph/graph_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Set
from typing import List, Set

import networkx as nx
import networkx.algorithms.isomorphism as ism

from nncf.common.graph.patterns import GraphPattern


def is_nodes_degrees_match(
graph: nx.DiGraph, pattern_graph: nx.DiGraph, mapping: Dict[str, str], first_node: str, last_node: str
):
def is_subgraph_has_inner_outgoing_edges(
graph: nx.DiGraph, full_subgraph_with_non_pattern_nodes: List[str], pattern_subgraph: List[str]
) -> bool:
"""
Checks amount of input and output edges for each node pairs in isomorphic mapping is matching
except for precessors of the first node and successors of the last node.
Isomorphic subgraphs could not have different edges between nodes inside subgraphs,
but could have connections to other nodes in grpah.
Checks out whether the 'pattern_subgraph' has outgoing edges,
that aren't connected with nodes from full_subgraph_with_non_pattern_nodes.
Example:
(conv2d + BN + ReLU pattern):
...
Expand All @@ -39,27 +37,30 @@ def is_nodes_degrees_match(
|
...
:param graph: The model graph.
:param pattern_graph: The pattern graph.
:param mapping: Mapping between graph nodes and pattern graph nodes.
:param first_node: Node key for starting node in matched subgraph.
:param last_node: Node key for ending node in matched subgraph.
:return: True if amount of input and output edges for each node pairs in isomorphic
mapping is matching except for presestors of the first node and successors of the last node,
False otherwise.
:param full_subgraph_with_non_pattern_nodes: A subgraph of the model graph including the nodes outside the pattern.
:param pattern_subgraph: A subgraph of the model.
:return: True if the subgraph contains outgoing edges starting not from the last node,
False - otherwise.
"""
for graph_key, pattern_key in mapping.items():
for attr in ["pred", "succ"]:
if graph_key == first_node and attr == "pred":
continue
if graph_key == last_node and attr == "succ":
continue

def _len(_graph, _key):
return len(getattr(_graph, attr)[_key].keys())

if not _len(graph, graph_key) == _len(pattern_graph, pattern_key):
return False
return True
first_node = pattern_subgraph[0]
last_node = pattern_subgraph[-1]
for node_key in pattern_subgraph:
if node_key == last_node:
predecessors = list(graph.pred[node_key].keys())
if any(predecessor not in full_subgraph_with_non_pattern_nodes for predecessor in predecessors):
return True
elif node_key == first_node:
successors = list(graph.succ[node_key].keys())
if any(successor not in full_subgraph_with_non_pattern_nodes for successor in successors):
return True
else:
successors = list(graph.succ[node_key].keys())
predecessors = list(graph.pred[node_key].keys())
if any(successors_key not in full_subgraph_with_non_pattern_nodes for successors_key in successors):
return True
if any(predecessor not in full_subgraph_with_non_pattern_nodes for predecessor in predecessors):
return True
return False


def find_subgraphs_matching_pattern(graph: nx.DiGraph, pattern_graph: GraphPattern) -> List[List[str]]:
Expand Down Expand Up @@ -113,9 +114,7 @@ def sort_patterns(pattern: nx.DiGraph):
"""
pattern_len = len(pattern)
for node in pattern.nodes:
if GraphPattern.NON_PATTERN_NODE_TYPE in pattern_graph.graph.nodes.get(node).get(
GraphPattern.METATYPE_ATTR, []
):
if GraphPattern.NON_PATTERN_NODE_TYPE in pattern_graph.graph.nodes.get(node)[GraphPattern.METATYPE_ATTR]:
pattern_len -= 1
return pattern_len

Expand All @@ -132,19 +131,24 @@ def sort_patterns(pattern: nx.DiGraph):
nx.lexicographical_topological_sort(graph.subgraph(subgraph), key=lambda x: int(x.split()[0]))
)

full_subgraph_with_non_pattern_nodes = pattern_subgraph[:]
outside_pattern_nodes = []

# If some nodes are outside the pattern - remove them from pattern_subgraph

for node, pattern_node_id in matcher.mapping.items():
pattern_node = pattern_graph.graph.nodes[pattern_node_id]
pattern_node_types = pattern_node.get(GraphPattern.METATYPE_ATTR, [])
pattern_node_types = pattern_node.get(GraphPattern.METATYPE_ATTR)
if GraphPattern.NON_PATTERN_NODE_TYPE in pattern_node_types:
pattern_subgraph.remove(node)
outside_pattern_nodes.append(node)
for node in outside_pattern_nodes:
pattern_subgraph.remove(node)

if any(node in visited_nodes for node in pattern_subgraph):
is_visited_node = any(node in visited_nodes for node in pattern_subgraph)
if is_visited_node:
continue

if not is_nodes_degrees_match(graph, pattern, matcher.mapping, pattern_subgraph[0], pattern_subgraph[-1]):
if is_subgraph_has_inner_outgoing_edges(graph, full_subgraph_with_non_pattern_nodes, pattern_subgraph):
continue

visited_nodes.update(pattern_subgraph)
subgraphs.append(pattern_subgraph)

Expand Down
6 changes: 3 additions & 3 deletions nncf/tensorflow/graph/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,10 +793,10 @@ def _get_conv_layer_attributes(layer: tf.keras.layers.Layer, is_depthwise: bool

return ConvolutionLayerAttributes(
weight_requires_grad=layer.trainable,
in_features=in_channels,
out_features=out_channels,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
strides=strides,
stride=strides,
dilations=dilations,
groups=groups,
transpose=transpose,
Expand Down
2 changes: 2 additions & 0 deletions tests/torch/ptq/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def get_depthwise_conv_nncf_graph() -> NNCFGraphToTestDepthwiseConv:
in_channels=3,
out_channels=3,
dilations=1,
kernel_size=(1, 1),
stride=(1, 1),
groups=1,
transpose=False,
padding_values=(1, 1),
Expand Down

0 comments on commit e5e0b31

Please sign in to comment.