diff --git a/source/python.js b/source/python.js index f98c3c6cbf..6c8e2129cb 100644 --- a/source/python.js +++ b/source/python.js @@ -9570,6 +9570,66 @@ python.Execution = class { return new torch.jit.Source(data.peek(), path); } }); + this.registerType('torch.jit.Environment', class { + constructor(method, resolver, b, next) { + this.method = method; + this.resolver = resolver; + this.b = b; + this.next = next; + this.value_table = new Map(); + this.type_table = new Map(); + } + getSugaredVar(ident, range, required) { + required = required || true; + let retval = this.findInAnyFrame(ident); + if (!retval) { + torch.jit.Environment.globals = torch.jit.Environment.globals || { + range: torch.jit.SpecialFormValue.create('prim::range') + }; + if (ident in torch.jit.Environment.globals) { + retval = torch.jit.Environment.globals[ident]; + } + } + if (!retval) { + // + } + if (!retval && required) { + throw new python.Error(`The name '${ident}' is not defined.`); + } + return retval; + } + findInAnyFrame(name) { + for (let runner = this; runner; runner = runner.next) { + const r = runner.findInThisFrame(name); + if (r) { + return r; + } + } + return null; + } + findInThisFrame(name) { + if (this.value_table.has(name)) { + return this.value_table.get(name); + } + if (this.type_table.has(name)) { + return this.insertLoad(name, this.type_table.get(name)); + } + return null; + } + }); + this.registerType('torch.jit.SugaredValue', class { + }); + this.registerType('torch.jit.SimpleValue', class extends torch.jit.SugaredValue { + }); + this.registerType('torch.jit.SpecialFormValue', class extends torch.jit.SugaredValue { + constructor(form) { + super(); + this._form = form; + } + static create(form) { + return new torch.jit.SpecialFormValue(form); + } + }); this.registerType('torch.package.PackageImporter', class { constructor(reader) { this.zip_reader = reader; diff --git a/source/pytorch.js b/source/pytorch.js index 04d3e35b0c..b3d5f37c9c 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -168,7 +168,9 @@ pytorch.Graph = class { } } for (const node of graph.nodes()) { - if (node.kind() === 'prim::ListConstruct' && node.inputs().every((value) => typeof value.value === 'number' && typeof value.value === 'string' && typeof value.value === 'boolean')) { + if (node.kind() === 'prim::ListConstruct' && + node.inputs().every((value) => typeof value.value === 'number' && typeof value.value === 'string' && typeof value.value === 'boolean') && + node.outputs().every((value) => value.uses().every((use) => use.user.kind() !== 'prim::CallMethod'))) { node.outputs()[0].value = node.inputs().map((value) => value.value); node.destroy(); } @@ -1632,6 +1634,7 @@ pytorch.Execution = class extends python.Execution { this._graph = this.invoke('torch.Graph', []); this._constants = new Map(); this._values = new Map(); + this.environment_stack = new torch.jit.Environment(/* */); } debug(file) { @@ -2245,12 +2248,21 @@ pytorch.Execution = class extends python.Execution { emitSugaredExpr(tree, n_binders, type_hint) { const ast = this.ast; - if (tree instanceof ast.Var) { - // + const torch = this.torch; + if (tree instanceof ast.Name) { // TK_VAR + return this.environment_stack.getSugaredVar(tree.id); } else if (tree instanceof ast.Attribute) { // - } else if (tree instanceof ast.Apply) { - // + } else if (tree instanceof ast.Call) { // TK_APPLY + const apply = tree; + const sv = this.emitSugaredExpr(apply.func, 1); + const loc = apply.func; + if (sv instanceof torch.jit.SpecialFormValue) { + // return emitApplySpecialForm(sv.form(), apply, sv, type_hint); + } + const args = this.getNamedValues(apply.inputs(), true); + const kwargs = this.emitAttributes(apply.attributes()); + return sv.call(loc, this.method, args, kwargs, n_binders); } if (tree instanceof ast.Subscript) { // } @@ -2439,7 +2451,9 @@ pytorch.Execution = class extends python.Execution { } /* // const sv = this.expression(itrs[0], context); - const sv = this.emitSugaredExpr(itrs[0], 1); + */ + // const sv = this.emitSugaredExpr(itrs[0], 1); + /* const iterable = sv.iter(range, method); if (iterable.shouldEmitUnrolled()) { this.emitUnrolledLoop(loc, emit_body, iterable, targets);