From 627fa7ecca502a0a73fa748c5cdeeef252cac53d Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sat, 28 Dec 2024 14:39:36 -0800 Subject: [PATCH] Update torch.fx test file (#1414) --- source/python.js | 90 +++++++++++++++++++++++++++++++++++++++-------- source/pytorch.js | 8 ++--- test/models.json | 2 +- 3 files changed, 78 insertions(+), 22 deletions(-) diff --git a/source/python.js b/source/python.js index 6c78fabb6d..f98c3c6cbf 100644 --- a/source/python.js +++ b/source/python.js @@ -4970,6 +4970,26 @@ python.Execution = class { children() { return this._modules.values(); } + named_modules(memo, prefix, remove_duplicate) { + memo = memo || new Set(); + prefix = prefix || ''; + const modules = new builtins.dict(); + if (!memo.has(this)) { + if (remove_duplicate) { + memo.add(this); + } + modules.set(prefix, this); + for (const [name, module] of this._modules.items()) { + if (module) { + const submodule_prefix = `${prefix}${(prefix ? '.' : '')}${name}`; + for (const [k, v] of module.named_modules(memo, submodule_prefix, remove_duplicate)) { + modules.set(k, v); + } + } + } + } + return modules; + } named_children() { return this._modules; } @@ -4991,6 +5011,9 @@ python.Execution = class { } return this._buffers; } + _get_name() { + return this.__class__.__name__; + } }); torch.nn.Module = torch.nn.modules.module.Module; torch.nn.modules.Module = torch.nn.modules.module.Module; @@ -5392,12 +5415,14 @@ python.Execution = class { this.registerType('torch.nn.modules.pooling.FractionalMaxPool2d', class {}); this.registerType('torch.nn.modules.pooling.LPPool1d', class {}); this.registerType('torch.nn.modules.pooling.LPPool2d', class {}); - this.registerType('torch.nn.modules.pooling.MaxPool1d', class {}); - this.registerType('torch.nn.modules.pooling.MaxPool2d', class {}); - this.registerType('torch.nn.modules.pooling.MaxPool3d', class {}); - this.registerType('torch.nn.modules.pooling.MaxUnpool1d', class {}); - this.registerType('torch.nn.modules.pooling.MaxUnpool2d', class {}); - this.registerType('torch.nn.modules.pooling.MaxUnpool3d', class {}); + this.registerType('torch.nn.modules.pooling._MaxPoolNd', class extends torch.nn.modules.module.Module {}); + this.registerType('torch.nn.modules.pooling.MaxPool1d', class extends torch.nn.modules.pooling._MaxPoolNd {}); + this.registerType('torch.nn.modules.pooling.MaxPool2d', class extends torch.nn.modules.pooling._MaxPoolNd {}); + this.registerType('torch.nn.modules.pooling.MaxPool3d', class extends torch.nn.modules.pooling._MaxPoolNd {}); + this.registerType('torch.nn.modules.pooling._MaxUnpoolNd', class extends torch.nn.modules.module.Module {}); + this.registerType('torch.nn.modules.pooling.MaxUnpool1d', class extends torch.nn.modules.pooling._MaxUnpoolNd {}); + this.registerType('torch.nn.modules.pooling.MaxUnpool2d', class extends torch.nn.modules.pooling._MaxUnpoolNd {}); + this.registerType('torch.nn.modules.pooling.MaxUnpool3d', class extends torch.nn.modules.pooling._MaxUnpoolNd {}); this.registerType('torch.nn.modules.rnn.GRU', class {}); this.registerType('torch.nn.modules.rnn.GRUCell', class {}); this.registerType('torch.nn.modules.rnn.LSTM', class {}); @@ -5561,17 +5586,44 @@ python.Execution = class { this.registerType('torch.utils.data.sampler.RandomSampler', class {}); this.registerType('torch.utils.data.sampler.SequentialSampler', class {}); this.registerType('torch.utils.data.sampler.SubsetRandomSampler', class {}); + torch.nn.Sequential = torch.nn.modules.container.Sequential; this.registerType('torch.fx.experimental.symbolic_shapes.ShapeEnv', class { create_symintnode(/* sym, hint, source */) { return new torch.SymInt(); } }); - this.registerType('torch.fx.proxy.TracerBase', class {}); + this.registerType('torch.fx.proxy.TracerBase', class { + constructor() { + this.traced_func_name = 'forward'; + } + }); this.registerType('torch.fx._symbolic_trace.Tracer', class extends torch.fx.proxy.TracerBase { - trace(/* root, concrete_args */) { - this.graph = new torch.fx.graph.Graph(/* tracer_cls=tracer_cls */); + trace(root /*, concrete_args */) { + let fn = null; + if (root instanceof torch.nn.Module) { + // torch.fx._lazy_graph_module._LazyGraphModule.force_recompile(root) + this.root = root; + fn = builtins.getattr(new builtins.type(root), this.traced_func_name); + this.root_module_name = root._get_name(); + this.submodule_paths = new builtins.dict(root.named_modules()); + } else { + this.root = new torch.nn.Module(); + fn = root; + } + const tracer_cls = builtins.getattr(this, '__class__', null); + this.graph = new torch.fx.graph.Graph(null, tracer_cls); + if (builtins.hasattr(this, '__code__')) { + const code = fn.__code__; + this.graph._co_fields = { + co_name: code.co_name, + co_filename: code.co_filename, + co_firstlineno: code.co_firstlineno, + }; + } return this.graph; - // throw new python.Error('Not implemented'); + } + is_leaf_module(m /*, module_qualified_name */) { + return (m.__module__.startsWith('torch.nn') || m.__module__.startsWith('torch.ao.nn')) && m instanceof torch.nn.Sequential === false; } }); this.registerType('torch.fx.experimental.proxy_tensor.PythonKeyTracer', class extends torch.fx._symbolic_trace.Tracer {}); @@ -5592,7 +5644,7 @@ python.Execution = class { return true; } }; - const com = {}; // _CodeOnlyModule(body) + const com = new torch.fx.graph_module._CodeOnlyModule(body); const tracer_extras = body.get('_tracer_extras', new builtins.dict()); const graph = new KeepModules().trace(com, tracer_extras); graph._tracer_cls = tracer_cls; @@ -5762,14 +5814,14 @@ python.Execution = class { torch.fx.Node = torch.fx.node.Node; torch.fx.graph.Node = torch.fx.node.Node; this.registerType('torch.fx.graph.Graph', class { - constructor() { + constructor(owning_module, tracer_cls, tracer_extras) { this._root = new torch.fx.node.Node(self, '', 'root', '', new builtins.list(), new builtins.dict()); this._used_names = new Map(); this._len = 0; this._graph_namespace = new torch.fx.graph._Namespace(); - // this._owning_module = owning_module - // this._tracer_cls = tracer_cls - // this._tracer_extras = tracer_extras + this._owning_module = owning_module; + this._tracer_cls = tracer_cls; + this._tracer_extras = tracer_extras; // this._codegen = CodeGen() // this._co_fields = {} } @@ -5829,6 +5881,14 @@ python.Execution = class { return chars.join(''); } }); + this.registerType('torch.fx.graph_module._CodeOnlyModule', class extends torch.nn.modules.module.Module { + constructor(body) { + super(); + for (const [k, v] of body.items()) { + builtins.setattr(this, k, v); + } + } + }); this.registerType('torch.fx.graph_module.GraphModule', class extends torch.nn.modules.module.Module { constructor(root, graph, class_name) { super(); diff --git a/source/pytorch.js b/source/pytorch.js index 8951a2ff4c..04d3e35b0c 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -290,12 +290,6 @@ pytorch.Graph = class { this.inputs.push(argument); } } - } else if (torch && module instanceof torch.fx.graph_module.GraphModule) { - const graph = module.graph; - for (const obj of graph.nodes) { - const node = new pytorch.Node(execution, metadata, obj.name, null, obj, null, values); - this.nodes.push(node); - } } else if (pytorch.Utility.isTensor(module)) { const node = new pytorch.Node(execution, metadata, null, type, { value: module }); this.nodes.push(node); @@ -642,6 +636,8 @@ pytorch.Node = class { const argument = new pytorch.Argument('value', [value]); this.outputs.push(argument); } + } else if (obj.op === 'root') { + this.type = { name: obj.op }; } else { throw new pytorch.Error(`Unsupported node operation '${obj.op}'.`); } diff --git a/test/models.json b/test/models.json index cd571a6bf7..476c7d9ab6 100644 --- a/test/models.json +++ b/test/models.json @@ -6631,7 +6631,7 @@ "target": "torch_fx_sample.pt", "source": "https://github.com/lutzroeder/netron/files/12889841/torch_fx_sample.pt.zip[torch_fx_sample.pt]", "format": "PyTorch v1.6", - "link": "https://github.com/lutzroeder/netron/issues/720" + "link": "https://github.com/lutzroeder/netron/issues/1414" }, { "type": "pytorch",