Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TorchFX] Export to torch.export.export_for_training #3075

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions nncf/experimental/torch/fx/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def _traverse_graph(
continue

visited.add(in_node.name)
# Do not traverse through constant users
# as a constant could be shared and
# some shared ops users could be redundant.
if in_node.op == "get_attr":
continue
input_nodes.extend(in_node.all_input_nodes)
input_nodes.extend(list(in_node.users))

Expand Down
4 changes: 3 additions & 1 deletion nncf/experimental/torch/fx/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ def get_edge_params(
if source_node.op in ("get_attr",):
tensor_shape = tuple(get_tensor_constant_from_node(source_node, model).shape)
elif "val" in source_node.meta:
if source_nncf_node.metatype is om.PTBatchNormMetatype:
if source_nncf_node.metatype is om.PTBatchNormMetatype and isinstance(
source_node.meta["val"], (tuple, list)
):
tensor = source_node.meta["val"][0]
elif source_nncf_node.metatype in [om.PTSplitMetatype, om.PTMaxMetatype, om.PTMinMetatype]:
tensor = source_node.meta["val"][output_idx]
Expand Down
14 changes: 9 additions & 5 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,17 @@ def shared_constants_unification_transformation(model: torch.fx.GraphModule):

:param model: Target Torch FX GraphModule
"""
prev_targets = {}
target_vs_constant = {}

for source_node in model.graph.nodes:
dist_node = list(source_node.users)
if source_node.target in prev_targets and source_node.op in ("get_attr",):
dist_node[0].replace_input_with(source_node, prev_targets[source_node.target])
if source_node.op != "get_attr" or not source_node.users:
continue

if source_node.target in target_vs_constant:
for user in list(source_node.users):
user.replace_input_with(source_node, target_vs_constant[source_node.target])
else:
prev_targets[source_node.target] = source_node
target_vs_constant[source_node.target] = source_node

model.graph.eliminate_dead_code()
model.recompile()
Expand Down Expand Up @@ -541,6 +544,7 @@ def _is_supported_batch_norm_for_training(node: torch.fx.Node):
Return True if the given node refers to an aten batch norm op QAT supports.
"""
supported_ops = [
torch.ops.aten.batch_norm.default,
torch.ops.aten._native_batch_norm_legit.default,
torch.ops.aten.cudnn_batch_norm.default,
torch.ops.aten.miopen_batch_norm.default,
Expand Down
12 changes: 12 additions & 0 deletions tests/torch/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ def pytest_addoption(parser: Parser):
"reference .dot files will be regenerated "
"using the current state of the repository.",
)
parser.addoption(
"--regen-json",
action="store_true",
default=False,
help="If specified, the "
"reference .json files will be regenerated "
"using the current state of the repository.",
)
parser.addoption(
"--torch-home", type=str, default=None, help="Path to cached test models, downloaded by torchvision"
)
Expand Down Expand Up @@ -120,6 +128,10 @@ def pytest_configure(config: Config):
if regen_dot:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if I would like to update json files only?

Copy link
Collaborator

@anzr299 anzr299 Nov 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The regen_dot was added there by mistake, I have updated the suggestion. https://github.com/openvinotoolkit/nncf/pull/3075/files/68bb7b6f09bba89935ae44a7aff6e8f3bee237ea#r1841827740

os.environ["NNCF_TEST_REGEN_DOT"] = "1"

regen_json = config.getoption("--regen-json", False)
if regen_json:
os.environ["NNCF_TEST_REGEN_JSON"] = "1"

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if regen_dot:
os.environ["NNCF_TEST_REGEN_DOT"] = "1"
regen_json = config.getoption("--regen-json", False)
if regen_json:
os.environ["NNCF_TEST_REGEN_JSON"] = "1"
for option, var in [("--regen-dot", "NNCF_TEST_REGEN_DOT"), ("--regen-json", "NNCF_TEST_REGEN_JSON")]:
regen = config.getoption(option, False)
if regen:
os.environ[env_var] = "1"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, please check


@pytest.fixture(scope="module")
def dataset_dir(request: FixtureRequest):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
strict digraph {
"0 _param_constant0" [id=0, type=get_attr];
"0 _conv_w" [id=0, type=get_attr];
"1 add" [id=1, type=add];
"2 conv2d_input" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 output" [id=4, type=output];
"0 _param_constant0" -> "1 add" [label="(1, 1, 1, 1)", style=solid];
"0 _conv_w" -> "1 add" [label="(1, 1, 1, 1)", style=solid];
"1 add" -> "3 conv2d" [label="(1, 1, 1, 1)", style=solid];
"2 conv2d_input" -> "3 conv2d" [label=None, style=solid];
"3 conv2d" -> "4 output" [label="(1, 1, 1, 1)", style=solid];
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
strict digraph {
"0 _param_constant0" [id=0, type=get_attr];
"0 _conv_w" [id=0, type=get_attr];
"1 add" [id=1, type=add];
"2 conv2d_input" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 output" [id=4, type=output];
"0 _param_constant0" -> "1 add" [label="(1, 1, 1, 1)", style=solid];
"0 _conv_w" -> "1 add" [label="(1, 1, 1, 1)", style=solid];
"1 add" -> "3 conv2d" [label="(1, 1, 1, 1)", style=solid];
"2 conv2d_input" -> "3 conv2d" [label=None, style=solid];
"3 conv2d" -> "4 output" [label="(1, 1, 3, 3)", style=solid];
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
strict digraph {
"0 _param_constant2" [id=0, type=get_attr];
"1 _param_constant3" [id=1, type=get_attr];
"2 conv2d_1_input" [id=2, type=input];
"3 conv2d_1" [id=3, type=conv2d];
"4 _tensor_constant0_1" [id=4, type=get_attr];
"0 conv_b_weight" [id=0, type=get_attr];
"1 conv_b_bias" [id=1, type=get_attr];
"2 bias" [id=2, type=get_attr];
"3 conv2d_1_input" [id=3, type=input];
"4 conv2d_1" [id=4, type=conv2d];
"5 add__1" [id=5, type=add_];
"6 output" [id=6, type=output];
"0 _param_constant2" -> "3 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"1 _param_constant3" -> "3 conv2d_1" [label="(3,)", style=solid];
"2 conv2d_1_input" -> "3 conv2d_1" [label=None, style=solid];
"3 conv2d_1" -> "5 add__1" [label="(1, 3, 3, 3)", style=solid];
"4 _tensor_constant0_1" -> "5 add__1" [label="(1,)", style=solid];
"0 conv_b_weight" -> "4 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"1 conv_b_bias" -> "4 conv2d_1" [label="(3,)", style=solid];
"2 bias" -> "5 add__1" [label="(1,)", style=solid];
"3 conv2d_1_input" -> "4 conv2d_1" [label=None, style=solid];
"4 conv2d_1" -> "5 add__1" [label="(1, 3, 3, 3)", style=solid];
"5 add__1" -> "6 output" [label="(1, 3, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -1,38 +1,36 @@
strict digraph {
"0 _param_constant0" [id=0, type=get_attr];
"1 _param_constant1" [id=1, type=get_attr];
"2 conv2d_input" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 _param_constant2" [id=4, type=get_attr];
"5 _param_constant3" [id=5, type=get_attr];
"6 conv2d_1" [id=6, type=conv2d];
"7 _tensor_constant0" [id=7, type=get_attr];
"8 add_" [id=8, type=add_];
"9 _tensor_constant0_1" [id=9, type=get_attr];
"10 add__1" [id=10, type=add_];
"11 add" [id=11, type=add];
"12 _param_constant4" [id=12, type=get_attr];
"13 _param_constant5" [id=13, type=get_attr];
"14 conv2d_2" [id=14, type=conv2d];
"15 _tensor_constant0_2" [id=15, type=get_attr];
"16 add_1" [id=16, type=add];
"17 output" [id=17, type=output];
"0 _param_constant0" -> "3 conv2d" [label="(3, 3, 1, 1)", style=solid];
"1 _param_constant1" -> "3 conv2d" [label="(3,)", style=solid];
"2 conv2d_input" -> "3 conv2d" [label=None, style=solid];
"3 conv2d" -> "6 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"3 conv2d" -> "8 add_" [label="(1, 3, 3, 3)", style=solid];
"4 _param_constant2" -> "6 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"5 _param_constant3" -> "6 conv2d_1" [label="(3,)", style=solid];
"6 conv2d_1" -> "10 add__1" [label="(1, 3, 3, 3)", style=solid];
"7 _tensor_constant0" -> "8 add_" [label="(1,)", style=solid];
"8 add_" -> "11 add" [label="(1, 3, 3, 3)", style=solid];
"9 _tensor_constant0_1" -> "10 add__1" [label="(1,)", style=solid];
"10 add__1" -> "11 add" [label="(1, 3, 3, 3)", style=solid];
"10 add__1" -> "17 output" [label="(1, 3, 3, 3)", style=solid];
"11 add" -> "14 conv2d_2" [label="(1, 3, 3, 3)", style=solid];
"12 _param_constant4" -> "14 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"13 _param_constant5" -> "14 conv2d_2" [label="(3,)", style=solid];
"14 conv2d_2" -> "16 add_1" [label="(1, 3, 3, 3)", style=solid];
"15 _tensor_constant0_2" -> "16 add_1" [label="(1,)", style=solid];
"0 conv_a_weight" [id=0, type=get_attr];
"1 conv_a_bias" [id=1, type=get_attr];
"2 conv_b_weight" [id=2, type=get_attr];
"3 conv_b_bias" [id=3, type=get_attr];
"4 conv_c_weight" [id=4, type=get_attr];
"5 conv_c_bias" [id=5, type=get_attr];
"6 bias" [id=6, type=get_attr];
"7 conv2d_input" [id=7, type=input];
"8 conv2d" [id=8, type=conv2d];
"9 conv2d_1" [id=9, type=conv2d];
"10 add_" [id=10, type=add_];
"11 add__1" [id=11, type=add_];
"12 add" [id=12, type=add];
"13 conv2d_2" [id=13, type=conv2d];
"14 add_1" [id=14, type=add];
"15 output" [id=15, type=output];
"0 conv_a_weight" -> "8 conv2d" [label="(3, 3, 1, 1)", style=solid];
"1 conv_a_bias" -> "8 conv2d" [label="(3,)", style=solid];
"2 conv_b_weight" -> "9 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"3 conv_b_bias" -> "9 conv2d_1" [label="(3,)", style=solid];
"4 conv_c_weight" -> "13 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"5 conv_c_bias" -> "13 conv2d_2" [label="(3,)", style=solid];
"6 bias" -> "10 add_" [label="(1,)", style=solid];
"6 bias" -> "11 add__1" [label="(1,)", style=solid];
"6 bias" -> "14 add_1" [label="(1,)", style=solid];
"7 conv2d_input" -> "8 conv2d" [label=None, style=solid];
"8 conv2d" -> "9 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"8 conv2d" -> "10 add_" [label="(1, 3, 3, 3)", style=solid];
"9 conv2d_1" -> "11 add__1" [label="(1, 3, 3, 3)", style=solid];
"10 add_" -> "12 add" [label="(1, 3, 3, 3)", style=solid];
"11 add__1" -> "12 add" [label="(1, 3, 3, 3)", style=solid];
"11 add__1" -> "15 output" [label="(1, 3, 3, 3)", style=solid];
"12 add" -> "13 conv2d_2" [label="(1, 3, 3, 3)", style=solid];
"13 conv2d_2" -> "14 add_1" [label="(1, 3, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
strict digraph {
"0 _param_constant0" [id=0, type=get_attr];
"1 _param_constant1" [id=1, type=get_attr];
"2 conv2d_input" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 _param_constant2" [id=4, type=get_attr];
"5 _param_constant3" [id=5, type=get_attr];
"6 conv2d_1" [id=6, type=conv2d];
"7 _tensor_constant0" [id=7, type=get_attr];
"0 conv_a_weight" [id=0, type=get_attr];
"1 conv_a_bias" [id=1, type=get_attr];
"2 conv_b_weight" [id=2, type=get_attr];
"3 conv_b_bias" [id=3, type=get_attr];
"4 bias" [id=4, type=get_attr];
"5 conv2d_input" [id=5, type=input];
"6 conv2d" [id=6, type=conv2d];
"7 conv2d_1" [id=7, type=conv2d];
"8 add_" [id=8, type=add_];
"9 _tensor_constant0_1" [id=9, type=get_attr];
"10 add__1" [id=10, type=add_];
"11 add" [id=11, type=add];
"12 output" [id=12, type=output];
"0 _param_constant0" -> "3 conv2d" [label="(3, 3, 1, 1)", style=solid];
"1 _param_constant1" -> "3 conv2d" [label="(3,)", style=solid];
"2 conv2d_input" -> "3 conv2d" [label=None, style=solid];
"3 conv2d" -> "6 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"3 conv2d" -> "8 add_" [label="(1, 3, 3, 3)", style=solid];
"4 _param_constant2" -> "6 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"5 _param_constant3" -> "6 conv2d_1" [label="(3,)", style=solid];
"6 conv2d_1" -> "10 add__1" [label="(1, 3, 3, 3)", style=solid];
"7 _tensor_constant0" -> "8 add_" [label="(1,)", style=solid];
"8 add_" -> "11 add" [label="(1, 3, 3, 3)", style=solid];
"8 add_" -> "12 output" [label="(1, 3, 3, 3)", style=solid];
"9 _tensor_constant0_1" -> "10 add__1" [label="(1,)", style=solid];
"10 add__1" -> "11 add" [label="(1, 3, 3, 3)", style=solid];
"11 add" -> "12 output" [label="(1, 3, 3, 3)", style=solid];
"9 add__1" [id=9, type=add_];
"10 add" [id=10, type=add];
"11 output" [id=11, type=output];
"0 conv_a_weight" -> "6 conv2d" [label="(3, 3, 1, 1)", style=solid];
"1 conv_a_bias" -> "6 conv2d" [label="(3,)", style=solid];
"2 conv_b_weight" -> "7 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"3 conv_b_bias" -> "7 conv2d_1" [label="(3,)", style=solid];
"4 bias" -> "8 add_" [label="(1,)", style=solid];
"4 bias" -> "9 add__1" [label="(1,)", style=solid];
"5 conv2d_input" -> "6 conv2d" [label=None, style=solid];
"6 conv2d" -> "7 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"6 conv2d" -> "8 add_" [label="(1, 3, 3, 3)", style=solid];
"7 conv2d_1" -> "9 add__1" [label="(1, 3, 3, 3)", style=solid];
"8 add_" -> "10 add" [label="(1, 3, 3, 3)", style=solid];
"8 add_" -> "11 output" [label="(1, 3, 3, 3)", style=solid];
"9 add__1" -> "10 add" [label="(1, 3, 3, 3)", style=solid];
"10 add" -> "11 output" [label="(1, 3, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
strict digraph {
"0 _param_constant0" [id=0, type=get_attr];
"1 _param_constant1" [id=1, type=get_attr];
"2 conv2d_input" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 _param_constant2" [id=4, type=get_attr];
"5 _param_constant3" [id=5, type=get_attr];
"6 conv2d_1_input" [id=6, type=input];
"7 conv2d_1" [id=7, type=conv2d];
"8 _tensor_constant0" [id=8, type=get_attr];
"0 conv_a_weight" [id=0, type=get_attr];
"1 conv_a_bias" [id=1, type=get_attr];
"2 conv_b_weight" [id=2, type=get_attr];
"3 conv_b_bias" [id=3, type=get_attr];
"4 bias" [id=4, type=get_attr];
"5 conv2d_input" [id=5, type=input];
"6 conv2d" [id=6, type=conv2d];
"7 conv2d_1_input" [id=7, type=input];
"8 conv2d_1" [id=8, type=conv2d];
"9 add_" [id=9, type=add_];
"10 _tensor_constant0_1" [id=10, type=get_attr];
"11 add__1" [id=11, type=add_];
"12 output" [id=12, type=output];
"0 _param_constant0" -> "3 conv2d" [label="(3, 3, 1, 1)", style=solid];
"1 _param_constant1" -> "3 conv2d" [label="(3,)", style=solid];
"2 conv2d_input" -> "3 conv2d" [label=None, style=solid];
"3 conv2d" -> "9 add_" [label="(1, 3, 3, 3)", style=solid];
"4 _param_constant2" -> "7 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"5 _param_constant3" -> "7 conv2d_1" [label="(3,)", style=solid];
"6 conv2d_1_input" -> "7 conv2d_1" [label=None, style=solid];
"7 conv2d_1" -> "11 add__1" [label="(1, 3, 3, 3)", style=solid];
"8 _tensor_constant0" -> "9 add_" [label="(1,)", style=solid];
"9 add_" -> "12 output" [label="(1, 3, 3, 3)", style=solid];
"10 _tensor_constant0_1" -> "11 add__1" [label="(1,)", style=solid];
"11 add__1" -> "12 output" [label="(1, 3, 3, 3)", style=solid];
"10 add__1" [id=10, type=add_];
"11 output" [id=11, type=output];
"0 conv_a_weight" -> "6 conv2d" [label="(3, 3, 1, 1)", style=solid];
"1 conv_a_bias" -> "6 conv2d" [label="(3,)", style=solid];
"2 conv_b_weight" -> "8 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"3 conv_b_bias" -> "8 conv2d_1" [label="(3,)", style=solid];
"4 bias" -> "9 add_" [label="(1,)", style=solid];
"4 bias" -> "10 add__1" [label="(1,)", style=solid];
"5 conv2d_input" -> "6 conv2d" [label=None, style=solid];
"6 conv2d" -> "9 add_" [label="(1, 3, 3, 3)", style=solid];
"7 conv2d_1_input" -> "8 conv2d_1" [label=None, style=solid];
"8 conv2d_1" -> "10 add__1" [label="(1, 3, 3, 3)", style=solid];
"9 add_" -> "11 output" [label="(1, 3, 3, 3)", style=solid];
"10 add__1" -> "11 output" [label="(1, 3, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -1,34 +1,33 @@
strict digraph {
"0 _param_constant0" [id=0, type=get_attr];
"1 _param_constant1" [id=1, type=get_attr];
"2 conv2d_input" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 _param_constant2" [id=4, type=get_attr];
"5 _param_constant3" [id=5, type=get_attr];
"6 conv2d_1" [id=6, type=conv2d];
"7 _tensor_constant0" [id=7, type=get_attr];
"8 add_" [id=8, type=add_];
"9 _tensor_constant0_1" [id=9, type=get_attr];
"10 add__1" [id=10, type=add_];
"11 add" [id=11, type=add];
"12 _param_constant4" [id=12, type=get_attr];
"13 _param_constant5" [id=13, type=get_attr];
"14 conv2d_2" [id=14, type=conv2d];
"15 output" [id=15, type=output];
"0 _param_constant0" -> "3 conv2d" [label="(3, 3, 1, 1)", style=solid];
"1 _param_constant1" -> "3 conv2d" [label="(3,)", style=solid];
"2 conv2d_input" -> "3 conv2d" [label=None, style=solid];
"3 conv2d" -> "6 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"3 conv2d" -> "8 add_" [label="(1, 3, 3, 3)", style=solid];
"4 _param_constant2" -> "6 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"5 _param_constant3" -> "6 conv2d_1" [label="(3,)", style=solid];
"6 conv2d_1" -> "10 add__1" [label="(1, 3, 3, 3)", style=solid];
"7 _tensor_constant0" -> "8 add_" [label="(1,)", style=solid];
"8 add_" -> "11 add" [label="(1, 3, 3, 3)", style=solid];
"9 _tensor_constant0_1" -> "10 add__1" [label="(1,)", style=solid];
"10 add__1" -> "11 add" [label="(1, 3, 3, 3)", style=solid];
"11 add" -> "14 conv2d_2" [label="(1, 3, 3, 3)", style=solid];
"12 _param_constant4" -> "14 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"13 _param_constant5" -> "14 conv2d_2" [label="(3,)", style=solid];
"14 conv2d_2" -> "15 output" [label="(1, 3, 3, 3)", style=solid];
"0 conv_a_weight" [id=0, type=get_attr];
"1 conv_a_bias" [id=1, type=get_attr];
"2 conv_b_weight" [id=2, type=get_attr];
"3 conv_b_bias" [id=3, type=get_attr];
"4 conv_c_weight" [id=4, type=get_attr];
"5 conv_c_bias" [id=5, type=get_attr];
"6 bias" [id=6, type=get_attr];
"7 conv2d_input" [id=7, type=input];
"8 conv2d" [id=8, type=conv2d];
"9 conv2d_1" [id=9, type=conv2d];
"10 add_" [id=10, type=add_];
"11 add__1" [id=11, type=add_];
"12 add" [id=12, type=add];
"13 conv2d_2" [id=13, type=conv2d];
"14 output" [id=14, type=output];
"0 conv_a_weight" -> "8 conv2d" [label="(3, 3, 1, 1)", style=solid];
"1 conv_a_bias" -> "8 conv2d" [label="(3,)", style=solid];
"2 conv_b_weight" -> "9 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"3 conv_b_bias" -> "9 conv2d_1" [label="(3,)", style=solid];
"4 conv_c_weight" -> "13 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"5 conv_c_bias" -> "13 conv2d_2" [label="(3,)", style=solid];
"6 bias" -> "10 add_" [label="(1,)", style=solid];
"6 bias" -> "11 add__1" [label="(1,)", style=solid];
"7 conv2d_input" -> "8 conv2d" [label=None, style=solid];
"8 conv2d" -> "9 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"8 conv2d" -> "10 add_" [label="(1, 3, 3, 3)", style=solid];
"9 conv2d_1" -> "11 add__1" [label="(1, 3, 3, 3)", style=solid];
"10 add_" -> "12 add" [label="(1, 3, 3, 3)", style=solid];
"11 add__1" -> "12 add" [label="(1, 3, 3, 3)", style=solid];
"12 add" -> "13 conv2d_2" [label="(1, 3, 3, 3)", style=solid];
"13 conv2d_2" -> "14 output" [label="(1, 3, 3, 3)", style=solid];
}
Loading
Loading