Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Common] port_id and parallel_edges_ids are added to the graph for analysis #2206

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
10 changes: 5 additions & 5 deletions nncf/common/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,12 +607,12 @@ def get_graph_for_structure_analysis(self, extended: bool = False) -> nx.DiGraph
else:
attrs_edge["style"] = "solid"
label["shape"] = edge[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR]
label[
"ports"
] = f"{edge[NNCFGraph.OUTPUT_PORT_ID_EDGE_ATTR]}\u2192{edge[NNCFGraph.INPUT_PORT_ID_EDGE_ATTR]}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that you tried to make the representation prettier, but even in 2023 there's no guarantee that every tool can handle Unicode well enough. Use -> to represent an arrow, that would be safer.


if label:
if "shape" in label and len(label) == 1:
attrs_edge["label"] = label["shape"]
else:
attrs_edge["label"] = ", ".join((f"{k}:{v}" for k, v in label.items()))
attrs_edge["label"] = "\n".join((f"{k}:{v}" for k, v in label.items()))
out_graph.add_edge(u, v, **attrs_edge)
return out_graph

Expand All @@ -636,7 +636,7 @@ def _get_graph_for_visualization(self) -> nx.DiGraph:
style = "solid"
edge_label = (
f"{edge[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR]} \\n"
f"{edge[NNCFGraph.OUTPUT_PORT_ID_EDGE_ATTR]} -> {edge[NNCFGraph.INPUT_PORT_ID_EDGE_ATTR]}"
f"{edge[NNCFGraph.OUTPUT_PORT_ID_EDGE_ATTR]} \u2192 {edge[NNCFGraph.INPUT_PORT_ID_EDGE_ATTR]}"
)
out_graph.add_edge(u, v, label=edge_label, style=style)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ strict digraph {
"9 /Output_3_0" [id=9, type=Output_3];
"10 /Output_2_4_0" [id=10, type=output];
"11 /Output_3_1_0" [id=11, type=output];
"0 /Input_1_0" -> "1 /Split_1_0";
"1 /Split_1_0" -> "5 /Output_1_0";
"1 /Split_1_0" -> "6 /Output_2_1_0";
"1 /Split_1_0" -> "7 /Output_2_2_0";
"1 /Split_1_0" -> "8 /Output_2_3_0";
"1 /Split_1_0" -> "9 /Output_3_0";
"1 /Split_1_0" -> "10 /Output_2_4_0";
"1 /Split_1_0" -> "11 /Output_3_1_0" [label="parallel_input_port_ids:[2, 3, 4, 5, 6, 7, 8, 9]"];
"0 /Input_1_0" -> "1 /Split_1_0" [label="shape:[1, 1, 1, 1]\nports:0→0", style=solid];
"1 /Split_1_0" -> "5 /Output_1_0" [label="shape:[1, 1, 1, 1]\nports:0→0", style=solid];
"1 /Split_1_0" -> "6 /Output_2_1_0" [label="shape:[1, 1, 1, 1]\nports:1→0", style=solid];
"1 /Split_1_0" -> "7 /Output_2_2_0" [label="shape:[1, 1, 1, 1]\nports:1→0", style=solid];
"1 /Split_1_0" -> "8 /Output_2_3_0" [label="shape:[1, 1, 1, 1]\nports:1→0", style=solid];
"1 /Split_1_0" -> "9 /Output_3_0" [label="shape:[1, 1, 1, 1]\nports:2→0", style=solid];
"1 /Split_1_0" -> "10 /Output_2_4_0" [label="shape:[1, 1, 1, 1]\nports:1→15", style=solid];
"1 /Split_1_0" -> "11 /Output_3_1_0" [label="parallel_input_port_ids:[2, 3, 4, 5, 6, 7, 8, 9]\nshape:[1, 1, 1, 1]\nports:2→1", style=solid];
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ strict digraph {
"9 /Output_3_0" [id=9, type=Output_3];
"10 /Output_2_4_0" [id=10, type=output];
"11 /Output_3_1_0" [id=11, type=output];
"0 /Input_1_0" -> "1 /Split_1_0";
"1 /Split_1_0" -> "2 /Dropout_1_0";
"1 /Split_1_0" -> "3 /Dropout_2_0";
"1 /Split_1_0" -> "4 /Dropout_3_0";
"2 /Dropout_1_0" -> "5 /Output_1_0";
"3 /Dropout_2_0" -> "6 /Output_2_1_0";
"3 /Dropout_2_0" -> "7 /Output_2_2_0";
"3 /Dropout_2_0" -> "8 /Output_2_3_0";
"3 /Dropout_2_0" -> "10 /Output_2_4_0";
"4 /Dropout_3_0" -> "9 /Output_3_0";
"4 /Dropout_3_0" -> "11 /Output_3_1_0" [label="parallel_input_port_ids:[2, 3, 4, 5, 6, 7, 8, 9]"];
"0 /Input_1_0" -> "1 /Split_1_0" [label="shape:[1, 1, 1, 1]\nports:0→0", style=solid];
"1 /Split_1_0" -> "2 /Dropout_1_0" [label="shape:[1, 1, 1, 1]\nports:0→0", style=solid];
"1 /Split_1_0" -> "3 /Dropout_2_0" [label="shape:[1, 1, 1, 1]\nports:1→0", style=solid];
"1 /Split_1_0" -> "4 /Dropout_3_0" [label="shape:[1, 1, 1, 1]\nports:2→0", style=solid];
"2 /Dropout_1_0" -> "5 /Output_1_0" [label="shape:[1, 1, 1, 1]\nports:0→0", style=solid];
"3 /Dropout_2_0" -> "6 /Output_2_1_0" [label="shape:[1, 1, 1, 1]\nports:0→0", style=solid];
"3 /Dropout_2_0" -> "7 /Output_2_2_0" [label="shape:[1, 1, 1, 1]\nports:1→0", style=solid];
"3 /Dropout_2_0" -> "8 /Output_2_3_0" [label="shape:[1, 1, 1, 1]\nports:2→0", style=solid];
"3 /Dropout_2_0" -> "10 /Output_2_4_0" [label="shape:[1, 1, 1, 1]\nports:1→15", style=solid];
"4 /Dropout_3_0" -> "9 /Output_3_0" [label="shape:[1, 1, 1, 1]\nports:0→0", style=solid];
"4 /Dropout_3_0" -> "11 /Output_3_1_0" [label="parallel_input_port_ids:[2, 3, 4, 5, 6, 7, 8, 9]\nshape:[1, 1, 1, 1]\nports:1→1", style=solid];
}
2 changes: 1 addition & 1 deletion tests/common/quantization/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class TestModes(Enum):
@pytest.mark.parametrize("mode", [TestModes.VALID, TestModes.WRONG_TENSOR_SHAPE, TestModes.WRONG_PARALLEL_EDGES])
def test_remove_nodes_and_reconnect_graph(mode: TestModes):
def _check_graphs(dot_file_name, nncf_graph) -> None:
nx_graph = nncf_graph.get_graph_for_structure_analysis()
nx_graph = nncf_graph.get_graph_for_structure_analysis(extended=True)
path_to_dot = DATA_ROOT / dot_file_name
compare_nx_graph_with_reference(nx_graph, path_to_dot, check_edge_attrs=True)

Expand Down
Loading