From 3753c465dbe3c73478a08488f0bada7f9e4dabd7 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Fri, 27 Dec 2024 01:08:20 -0800 Subject: [PATCH] Add torch.fx test file (#1414) --- source/python.js | 37 ++++++++++++++++++++++++++++++++++--- source/pytorch.js | 11 +++++------ test/models.json | 7 +++++++ 3 files changed, 46 insertions(+), 9 deletions(-) diff --git a/source/python.js b/source/python.js index 0d53bb96e7..86e3f3918e 100644 --- a/source/python.js +++ b/source/python.js @@ -5573,11 +5573,42 @@ python.Execution = class { } }); this.registerType('torch.fx.proxy.TracerBase', class {}); - this.registerType('torch.fx._symbolic_trace.Tracer', class extends torch.fx.proxy.TracerBase {}); + this.registerType('torch.fx._symbolic_trace.Tracer', class extends torch.fx.proxy.TracerBase { + trace(/* root, concrete_args */) { + this.graph = new torch.fx.graph.Graph(/* tracer_cls=tracer_cls */); + return this.graph; + // throw new python.Error('Not implemented'); + } + }); this.registerType('torch.fx.experimental.proxy_tensor.PythonKeyTracer', class extends torch.fx._symbolic_trace.Tracer {}); this.registerType('torch.fx.experimental.proxy_tensor._ModuleStackTracer', class extends torch.fx.experimental.proxy_tensor.PythonKeyTracer {}); - this.registerFunction('torch.fx.graph_module._deserialize_graph_module', (/* forward, body */) => { - return execution.invoke('torch.fx.graph_module.GraphModule', []); + this.registerFunction('torch.fx._lazy_graph_module._make_graph_module', (...args) => { + const graph_module_cls = args.pop() || torch.fx.graph_module.GraphModule; + return new graph_module_cls(...args); + }); + this.registerFunction('torch.fx.graph_module._deserialize_graph_module', (forward, body, graph_module_cls) => { + let tracer_cls = body.get('_tracer_cls'); + if (!tracer_cls) { + tracer_cls = torch.fx._symbolic_trace.Tracer; + } + const graphmodule_cls_name = body.get('_graphmodule_cls_name', 'GraphModule'); + const cls_tracer = tracer_cls; + const KeepModules = class extends cls_tracer { + is_leaf_module() { + return true; + } + }; + const com = {}; // _CodeOnlyModule(body) + const tracer_extras = body.get('_tracer_extras', new builtins.dict()); + const graph = new KeepModules().trace(com, tracer_extras); + graph._tracer_cls = tracer_cls; + const gm = torch.fx._lazy_graph_module._make_graph_module(com, graph, graphmodule_cls_name, graph_module_cls); + for (const [k, v] of body.items()) { + if (!builtins.hasattr(gm, k)) { + builtins.setattr(gm, k, v); + } + } + return gm; }); this.registerFunction('torch.fx.graph_module._forward_from_src', (src, globals /*, co_fields */) => { globals = { ...globals }; diff --git a/source/pytorch.js b/source/pytorch.js index 9bdce84095..8951a2ff4c 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -290,13 +290,12 @@ pytorch.Graph = class { this.inputs.push(argument); } } - /* - for (const output_spec of exported_program.graph_signature.user_outputs()) { - const value = values.map(output_spec); - const argument = new pytorch.Argument(output_spec, [value]); - this.outputs.push(argument); + } else if (torch && module instanceof torch.fx.graph_module.GraphModule) { + const graph = module.graph; + for (const obj of graph.nodes) { + const node = new pytorch.Node(execution, metadata, obj.name, null, obj, null, values); + this.nodes.push(node); } - */ } else if (pytorch.Utility.isTensor(module)) { const node = new pytorch.Node(execution, metadata, null, type, { value: module }); this.nodes.push(node); diff --git a/test/models.json b/test/models.json index 07dcab1640..cd571a6bf7 100644 --- a/test/models.json +++ b/test/models.json @@ -5319,6 +5319,13 @@ "assert": "model.graphs[0].nodes.length == 25", "link": "https://github.com/lutzroeder/netron/issues/281" }, + { + "type": "pytorch", + "target": "alexnet.fx.pth", + "source": "https://github.com/user-attachments/files/18259127/alexnet.fx.zip[alexnet.fx.pth]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/1414" + }, { "type": "pytorch", "target": "alexnet.pkl.pth.zip",