diff --git a/lib/model/nns/graph.js b/lib/model/nns/graph.js index 9987a3e7..ce7fe7b0 100644 --- a/lib/model/nns/graph.js +++ b/lib/model/nns/graph.js @@ -121,7 +121,8 @@ export default class ComputationalGraph { let s = 'digraph g {\n' for (let i = 0; i < this._nodes.length; i++) { const node = this.nodes[i] - s += ` l${i} [label="${node.layer.constructor.name}\\n${node.name}"];\n` + const label = node.layer.constructor.name + (node.name ? `\\n${node.name}` : '') + s += ` l${i} [label="${label}"];\n` for (const parent of node.parents) { s += ` l${parent.index} -> l${i};\n` } diff --git a/tests/lib/model/nns/graph.test.js b/tests/lib/model/nns/graph.test.js index 54d0a6a8..d9d1bd84 100644 --- a/tests/lib/model/nns/graph.test.js +++ b/tests/lib/model/nns/graph.test.js @@ -158,9 +158,9 @@ describe('Computational Graph', () => { test('toDot', () => { const graph = new ComputationalGraph() graph.add(Layer.fromObject({ type: 'input' })) - graph.add(Layer.fromObject({ type: 'tanh' })) + graph.add(Layer.fromObject({ type: 'tanh' }), 't') expect(graph.toDot()).toBe( - 'digraph g {\n l0 [label="InputLayer\\nundefined"];\n l1 [label="TanhLayer\\nundefined"];\n l0 -> l1;\n}' + 'digraph g {\n l0 [label="InputLayer"];\n l1 [label="TanhLayer\\nt"];\n l0 -> l1;\n}' ) })