Skip to content

Commit

Permalink
Add torch.fx test file (#1414)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 27, 2024
1 parent dd3158b commit 3753c46
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 9 deletions.
37 changes: 34 additions & 3 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -5573,11 +5573,42 @@ python.Execution = class {
}
});
this.registerType('torch.fx.proxy.TracerBase', class {});
this.registerType('torch.fx._symbolic_trace.Tracer', class extends torch.fx.proxy.TracerBase {});
this.registerType('torch.fx._symbolic_trace.Tracer', class extends torch.fx.proxy.TracerBase {
trace(/* root, concrete_args */) {
this.graph = new torch.fx.graph.Graph(/* tracer_cls=tracer_cls */);
return this.graph;
// throw new python.Error('Not implemented');
}
});
this.registerType('torch.fx.experimental.proxy_tensor.PythonKeyTracer', class extends torch.fx._symbolic_trace.Tracer {});
this.registerType('torch.fx.experimental.proxy_tensor._ModuleStackTracer', class extends torch.fx.experimental.proxy_tensor.PythonKeyTracer {});
this.registerFunction('torch.fx.graph_module._deserialize_graph_module', (/* forward, body */) => {
return execution.invoke('torch.fx.graph_module.GraphModule', []);
this.registerFunction('torch.fx._lazy_graph_module._make_graph_module', (...args) => {
const graph_module_cls = args.pop() || torch.fx.graph_module.GraphModule;
return new graph_module_cls(...args);
});
this.registerFunction('torch.fx.graph_module._deserialize_graph_module', (forward, body, graph_module_cls) => {
let tracer_cls = body.get('_tracer_cls');
if (!tracer_cls) {
tracer_cls = torch.fx._symbolic_trace.Tracer;
}
const graphmodule_cls_name = body.get('_graphmodule_cls_name', 'GraphModule');
const cls_tracer = tracer_cls;
const KeepModules = class extends cls_tracer {
is_leaf_module() {
return true;
}
};
const com = {}; // _CodeOnlyModule(body)
const tracer_extras = body.get('_tracer_extras', new builtins.dict());
const graph = new KeepModules().trace(com, tracer_extras);
graph._tracer_cls = tracer_cls;
const gm = torch.fx._lazy_graph_module._make_graph_module(com, graph, graphmodule_cls_name, graph_module_cls);
for (const [k, v] of body.items()) {
if (!builtins.hasattr(gm, k)) {
builtins.setattr(gm, k, v);
}
}
return gm;
});
this.registerFunction('torch.fx.graph_module._forward_from_src', (src, globals /*, co_fields */) => {
globals = { ...globals };
Expand Down
11 changes: 5 additions & 6 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -290,13 +290,12 @@ pytorch.Graph = class {
this.inputs.push(argument);
}
}
/*
for (const output_spec of exported_program.graph_signature.user_outputs()) {
const value = values.map(output_spec);
const argument = new pytorch.Argument(output_spec, [value]);
this.outputs.push(argument);
} else if (torch && module instanceof torch.fx.graph_module.GraphModule) {
const graph = module.graph;
for (const obj of graph.nodes) {
const node = new pytorch.Node(execution, metadata, obj.name, null, obj, null, values);
this.nodes.push(node);
}
*/
} else if (pytorch.Utility.isTensor(module)) {
const node = new pytorch.Node(execution, metadata, null, type, { value: module });
this.nodes.push(node);
Expand Down
7 changes: 7 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -5319,6 +5319,13 @@
"assert": "model.graphs[0].nodes.length == 25",
"link": "https://github.com/lutzroeder/netron/issues/281"
},
{
"type": "pytorch",
"target": "alexnet.fx.pth",
"source": "https://github.com/user-attachments/files/18259127/alexnet.fx.zip[alexnet.fx.pth]",
"format": "PyTorch v1.6",
"link": "https://github.com/lutzroeder/netron/issues/1414"
},
{
"type": "pytorch",
"target": "alexnet.pkl.pth.zip",
Expand Down

0 comments on commit 3753c46

Please sign in to comment.