Skip to content

Commit

Permalink
Update pytorch.js (#842)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 29, 2024
1 parent dfd9795 commit 39dab9a
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 6 deletions.
60 changes: 60 additions & 0 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -9570,6 +9570,66 @@ python.Execution = class {
return new torch.jit.Source(data.peek(), path);
}
});
this.registerType('torch.jit.Environment', class {
constructor(method, resolver, b, next) {
this.method = method;
this.resolver = resolver;
this.b = b;
this.next = next;
this.value_table = new Map();
this.type_table = new Map();
}
getSugaredVar(ident, range, required) {
required = required || true;
let retval = this.findInAnyFrame(ident);
if (!retval) {
torch.jit.Environment.globals = torch.jit.Environment.globals || {
range: torch.jit.SpecialFormValue.create('prim::range')
};
if (ident in torch.jit.Environment.globals) {
retval = torch.jit.Environment.globals[ident];
}
}
if (!retval) {
//
}
if (!retval && required) {
throw new python.Error(`The name '${ident}' is not defined.`);
}
return retval;
}
findInAnyFrame(name) {
for (let runner = this; runner; runner = runner.next) {
const r = runner.findInThisFrame(name);
if (r) {
return r;
}
}
return null;
}
findInThisFrame(name) {
if (this.value_table.has(name)) {
return this.value_table.get(name);
}
if (this.type_table.has(name)) {
return this.insertLoad(name, this.type_table.get(name));
}
return null;
}
});
this.registerType('torch.jit.SugaredValue', class {
});
this.registerType('torch.jit.SimpleValue', class extends torch.jit.SugaredValue {
});
this.registerType('torch.jit.SpecialFormValue', class extends torch.jit.SugaredValue {
constructor(form) {
super();
this._form = form;
}
static create(form) {
return new torch.jit.SpecialFormValue(form);
}
});
this.registerType('torch.package.PackageImporter', class {
constructor(reader) {
this.zip_reader = reader;
Expand Down
26 changes: 20 additions & 6 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,9 @@ pytorch.Graph = class {
}
}
for (const node of graph.nodes()) {
if (node.kind() === 'prim::ListConstruct' && node.inputs().every((value) => typeof value.value === 'number' && typeof value.value === 'string' && typeof value.value === 'boolean')) {
if (node.kind() === 'prim::ListConstruct' &&
node.inputs().every((value) => typeof value.value === 'number' && typeof value.value === 'string' && typeof value.value === 'boolean') &&
node.outputs().every((value) => value.uses().every((use) => use.user.kind() !== 'prim::CallMethod'))) {
node.outputs()[0].value = node.inputs().map((value) => value.value);
node.destroy();
}
Expand Down Expand Up @@ -1632,6 +1634,7 @@ pytorch.Execution = class extends python.Execution {
this._graph = this.invoke('torch.Graph', []);
this._constants = new Map();
this._values = new Map();
this.environment_stack = new torch.jit.Environment(/* */);
}

debug(file) {
Expand Down Expand Up @@ -2245,12 +2248,21 @@ pytorch.Execution = class extends python.Execution {

emitSugaredExpr(tree, n_binders, type_hint) {
const ast = this.ast;
if (tree instanceof ast.Var) {
//
const torch = this.torch;
if (tree instanceof ast.Name) { // TK_VAR
return this.environment_stack.getSugaredVar(tree.id);
} else if (tree instanceof ast.Attribute) {
//
} else if (tree instanceof ast.Apply) {
//
} else if (tree instanceof ast.Call) { // TK_APPLY
const apply = tree;
const sv = this.emitSugaredExpr(apply.func, 1);
const loc = apply.func;
if (sv instanceof torch.jit.SpecialFormValue) {
// return emitApplySpecialForm(sv.form(), apply, sv, type_hint);
}
const args = this.getNamedValues(apply.inputs(), true);
const kwargs = this.emitAttributes(apply.attributes());
return sv.call(loc, this.method, args, kwargs, n_binders);
} if (tree instanceof ast.Subscript) {
//
}
Expand Down Expand Up @@ -2439,7 +2451,9 @@ pytorch.Execution = class extends python.Execution {
}
/*
// const sv = this.expression(itrs[0], context);
const sv = this.emitSugaredExpr(itrs[0], 1);
*/
// const sv = this.emitSugaredExpr(itrs[0], 1);
/*
const iterable = sv.iter(range, method);
if (iterable.shouldEmitUnrolled()) {
this.emitUnrolledLoop(loc, emit_body, iterable, targets);
Expand Down

0 comments on commit 39dab9a

Please sign in to comment.