Skip to content

Commit

Permalink
Update torch.fx test file (#1414)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 28, 2024
1 parent d162347 commit 627fa7e
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 22 deletions.
90 changes: 75 additions & 15 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -4970,6 +4970,26 @@ python.Execution = class {
children() {
return this._modules.values();
}
named_modules(memo, prefix, remove_duplicate) {
memo = memo || new Set();
prefix = prefix || '';
const modules = new builtins.dict();
if (!memo.has(this)) {
if (remove_duplicate) {
memo.add(this);
}
modules.set(prefix, this);
for (const [name, module] of this._modules.items()) {
if (module) {
const submodule_prefix = `${prefix}${(prefix ? '.' : '')}${name}`;
for (const [k, v] of module.named_modules(memo, submodule_prefix, remove_duplicate)) {
modules.set(k, v);
}
}
}
}
return modules;
}
named_children() {
return this._modules;
}
Expand All @@ -4991,6 +5011,9 @@ python.Execution = class {
}
return this._buffers;
}
_get_name() {
return this.__class__.__name__;
}
});
torch.nn.Module = torch.nn.modules.module.Module;
torch.nn.modules.Module = torch.nn.modules.module.Module;
Expand Down Expand Up @@ -5392,12 +5415,14 @@ python.Execution = class {
this.registerType('torch.nn.modules.pooling.FractionalMaxPool2d', class {});
this.registerType('torch.nn.modules.pooling.LPPool1d', class {});
this.registerType('torch.nn.modules.pooling.LPPool2d', class {});
this.registerType('torch.nn.modules.pooling.MaxPool1d', class {});
this.registerType('torch.nn.modules.pooling.MaxPool2d', class {});
this.registerType('torch.nn.modules.pooling.MaxPool3d', class {});
this.registerType('torch.nn.modules.pooling.MaxUnpool1d', class {});
this.registerType('torch.nn.modules.pooling.MaxUnpool2d', class {});
this.registerType('torch.nn.modules.pooling.MaxUnpool3d', class {});
this.registerType('torch.nn.modules.pooling._MaxPoolNd', class extends torch.nn.modules.module.Module {});
this.registerType('torch.nn.modules.pooling.MaxPool1d', class extends torch.nn.modules.pooling._MaxPoolNd {});
this.registerType('torch.nn.modules.pooling.MaxPool2d', class extends torch.nn.modules.pooling._MaxPoolNd {});
this.registerType('torch.nn.modules.pooling.MaxPool3d', class extends torch.nn.modules.pooling._MaxPoolNd {});
this.registerType('torch.nn.modules.pooling._MaxUnpoolNd', class extends torch.nn.modules.module.Module {});
this.registerType('torch.nn.modules.pooling.MaxUnpool1d', class extends torch.nn.modules.pooling._MaxUnpoolNd {});
this.registerType('torch.nn.modules.pooling.MaxUnpool2d', class extends torch.nn.modules.pooling._MaxUnpoolNd {});
this.registerType('torch.nn.modules.pooling.MaxUnpool3d', class extends torch.nn.modules.pooling._MaxUnpoolNd {});
this.registerType('torch.nn.modules.rnn.GRU', class {});
this.registerType('torch.nn.modules.rnn.GRUCell', class {});
this.registerType('torch.nn.modules.rnn.LSTM', class {});
Expand Down Expand Up @@ -5561,17 +5586,44 @@ python.Execution = class {
this.registerType('torch.utils.data.sampler.RandomSampler', class {});
this.registerType('torch.utils.data.sampler.SequentialSampler', class {});
this.registerType('torch.utils.data.sampler.SubsetRandomSampler', class {});
torch.nn.Sequential = torch.nn.modules.container.Sequential;
this.registerType('torch.fx.experimental.symbolic_shapes.ShapeEnv', class {
create_symintnode(/* sym, hint, source */) {
return new torch.SymInt();
}
});
this.registerType('torch.fx.proxy.TracerBase', class {});
this.registerType('torch.fx.proxy.TracerBase', class {
constructor() {
this.traced_func_name = 'forward';
}
});
this.registerType('torch.fx._symbolic_trace.Tracer', class extends torch.fx.proxy.TracerBase {
trace(/* root, concrete_args */) {
this.graph = new torch.fx.graph.Graph(/* tracer_cls=tracer_cls */);
trace(root /*, concrete_args */) {
let fn = null;
if (root instanceof torch.nn.Module) {
// torch.fx._lazy_graph_module._LazyGraphModule.force_recompile(root)
this.root = root;
fn = builtins.getattr(new builtins.type(root), this.traced_func_name);
this.root_module_name = root._get_name();
this.submodule_paths = new builtins.dict(root.named_modules());
} else {
this.root = new torch.nn.Module();
fn = root;
}
const tracer_cls = builtins.getattr(this, '__class__', null);
this.graph = new torch.fx.graph.Graph(null, tracer_cls);
if (builtins.hasattr(this, '__code__')) {
const code = fn.__code__;
this.graph._co_fields = {
co_name: code.co_name,
co_filename: code.co_filename,
co_firstlineno: code.co_firstlineno,
};
}
return this.graph;
// throw new python.Error('Not implemented');
}
is_leaf_module(m /*, module_qualified_name */) {
return (m.__module__.startsWith('torch.nn') || m.__module__.startsWith('torch.ao.nn')) && m instanceof torch.nn.Sequential === false;
}
});
this.registerType('torch.fx.experimental.proxy_tensor.PythonKeyTracer', class extends torch.fx._symbolic_trace.Tracer {});
Expand All @@ -5592,7 +5644,7 @@ python.Execution = class {
return true;
}
};
const com = {}; // _CodeOnlyModule(body)
const com = new torch.fx.graph_module._CodeOnlyModule(body);
const tracer_extras = body.get('_tracer_extras', new builtins.dict());
const graph = new KeepModules().trace(com, tracer_extras);
graph._tracer_cls = tracer_cls;
Expand Down Expand Up @@ -5762,14 +5814,14 @@ python.Execution = class {
torch.fx.Node = torch.fx.node.Node;
torch.fx.graph.Node = torch.fx.node.Node;
this.registerType('torch.fx.graph.Graph', class {
constructor() {
constructor(owning_module, tracer_cls, tracer_extras) {
this._root = new torch.fx.node.Node(self, '', 'root', '', new builtins.list(), new builtins.dict());
this._used_names = new Map();
this._len = 0;
this._graph_namespace = new torch.fx.graph._Namespace();
// this._owning_module = owning_module
// this._tracer_cls = tracer_cls
// this._tracer_extras = tracer_extras
this._owning_module = owning_module;
this._tracer_cls = tracer_cls;
this._tracer_extras = tracer_extras;
// this._codegen = CodeGen()
// this._co_fields = {}
}
Expand Down Expand Up @@ -5829,6 +5881,14 @@ python.Execution = class {
return chars.join('');
}
});
this.registerType('torch.fx.graph_module._CodeOnlyModule', class extends torch.nn.modules.module.Module {
constructor(body) {
super();
for (const [k, v] of body.items()) {
builtins.setattr(this, k, v);
}
}
});
this.registerType('torch.fx.graph_module.GraphModule', class extends torch.nn.modules.module.Module {
constructor(root, graph, class_name) {
super();
Expand Down
8 changes: 2 additions & 6 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,6 @@ pytorch.Graph = class {
this.inputs.push(argument);
}
}
} else if (torch && module instanceof torch.fx.graph_module.GraphModule) {
const graph = module.graph;
for (const obj of graph.nodes) {
const node = new pytorch.Node(execution, metadata, obj.name, null, obj, null, values);
this.nodes.push(node);
}
} else if (pytorch.Utility.isTensor(module)) {
const node = new pytorch.Node(execution, metadata, null, type, { value: module });
this.nodes.push(node);
Expand Down Expand Up @@ -642,6 +636,8 @@ pytorch.Node = class {
const argument = new pytorch.Argument('value', [value]);
this.outputs.push(argument);
}
} else if (obj.op === 'root') {
this.type = { name: obj.op };
} else {
throw new pytorch.Error(`Unsupported node operation '${obj.op}'.`);
}
Expand Down
2 changes: 1 addition & 1 deletion test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -6631,7 +6631,7 @@
"target": "torch_fx_sample.pt",
"source": "https://github.com/lutzroeder/netron/files/12889841/torch_fx_sample.pt.zip[torch_fx_sample.pt]",
"format": "PyTorch v1.6",
"link": "https://github.com/lutzroeder/netron/issues/720"
"link": "https://github.com/lutzroeder/netron/issues/1414"
},
{
"type": "pytorch",
Expand Down

0 comments on commit 627fa7e

Please sign in to comment.