diff --git a/source/python.js b/source/python.js
index b4bb165022..0840598ac8 100644
--- a/source/python.js
+++ b/source/python.js
@@ -5102,6 +5102,12 @@ python.Execution = class {
             getOperation(/* node */) {
                 return null;
             }
+            aliasAnalysisKind() {
+                const schemaRef = this.schema();
+                const alias_analysis = schemaRef.aliasAnalysis();
+                torch._C.TORCH_CHECK(alias_analysis === 'FROM_SCHEMA' || !schemaRef.hasAnyAliasInfo());
+                return alias_analysis;
+            }
         });
         this.registerFunction('torch._C.getRegistry', () => {
             torch._C.r = torch._C.r || new torch._C.OperatorRegistry();
@@ -5180,11 +5186,237 @@ python.Execution = class {
                 return this._schema;
             }
         });
-        this.registerFunction('torch._C.EliminateDeadCode', (/* graph */) => {
-
+        this.registerType('torch._C.DeadCodeEliminator', class {
+            constructor(graph, sideEffectPolicy) {
+                this._graph = graph;
+                this._sideEffectPolicy = sideEffectPolicy;
+                this._useAliasDb = true;
+                this._aliasDb = null;
+                this._memo = new Map();
+                this._marked = new Set();
+                this._liveValues = new Set();
+                this._deleteCallback = () => {};
+            }
+            run(block, recurse) {
+                this.eliminateDeadForkInputs(block, recurse);
+                this.mark(block.return_node());
+                this.mark(block);
+                this._deleteCallback(this._liveValues);
+                this.sweep(block, recurse);
+            }
+            setDeleteCallback(deleteCallback) {
+                this._deleteCallback = deleteCallback;
+            }
+            eliminateDeadForkInputs(block, recurse) {
+                for (const node of block.nodes()) {
+                    if (recurse) {
+                        for (const sb of node.blocks()) {
+                            this.eliminateDeadForkInputs(sb, recurse);
+                        }
+                    }
+                    if (node.kind() !== 'prim::fork') {
+                        continue;
+                    }
+                    const g = node.g("Subgraph");
+                    for (let i = 0; i < g.inputs().length; i++) {
+                        if (!g.inputs()[i].hasUses()) {
+                            g.eraseInput(i);
+                            node.removeInput(i);
+                        }
+                    }
+                }
+            }
+            markReturnNode(node) {
+                if (this._marked.has(node)) {
+                    return false;
+                }
+                torch._C.AT_ASSERT(node.owningBlock().return_node() === node);
+                const outerNode = node.owningBlock().owningNode();
+                if (outerNode === null || outerNode.kind() === 'prim::Reverse') {
+                    return this.mark(node);
+                }
+                if (outerNode.kind() === 'prim::Loop' || outerNode.kind() === 'c10::onnx::Loop') {
+                    throw new python.Error('Not implemented.');
+                    /*
+                    const loop = new torch._C.LoopView(outerNode);
+                    for (const auto i : c10::irange(loop.carriedOutputs().size())) {
+                        if (outerNode.kind() == c10::onnx::Loop) {
+                            this._liveValues.add(loop.bodyCarriedOutputs()[i]);
+                            continue;
+                        }
+                        auto innerInput = loop.bodyCarriedInputs()[i];
+                        auto innerOutput = loop.bodyCarriedOutputs()[i];
+                        auto outerOutput = loop.carriedOutputs()[i];
+                        if (liveValues_.count(outerOutput) || innerInput->hasUses()) {
+                            this._liveValues.add(innerOutput);
+                        }
+                    }
+                    this._liveValues.add(loop.nextCond());
+                    */
+                } else {
+                    torch._C.AT_ASSERT(outerNode.outputs().length === node.inputs().length);
+                    for (let i = 0; i < outerNode.outputs().length; i++) {
+                        const innerOutput = node.inputs()[i];
+                        const outerOutput = outerNode.outputs()[i];
+                        // An inner output is live only if its corresponding outer
+                        // output is live (the original check was inverted).
+                        if (this._liveValues.has(outerOutput)) {
+                            this._liveValues.add(innerOutput);
+                        }
+                    }
+                }
+                this._marked.add(node);
+                return true;
+            }
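+            // markLoop is called by mark() below but was not defined. A minimal
+            // sketch following the upstream DeadCodeEliminator::markLoop: re-mark
+            // the loop body until no new nodes become live, since loop-carried
+            // dependencies need a fixed point.
+            markLoop(node) {
+                torch._C.TORCH_INTERNAL_ASSERT(node.kind() === 'prim::Loop');
+                let anyMarked = false;
+                let marked = false;
+                do {
+                    marked = this.mark(node.blocks()[0]);
+                    anyMarked = anyMarked || marked;
+                } while (marked);
+                return anyMarked;
+            }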
+            mark(...args) {
+                if (args.length === 1 && args[0] instanceof torch.Block) {
+                    const [block] = args;
+                    let anyMarked = false;
+                    for (const node of block.nodes()) {
+                        if (this._sideEffectPolicy === 'DONT_DELETE_NODES_WITH_SIDE_EFFECTS' && this.hasSideEffects(node)) {
+                            const marked = this.mark(node);
+                            anyMarked = anyMarked || marked;
+                        }
+                    }
+                    const marked = this.markReturnNode(block.return_node());
+                    anyMarked = anyMarked || marked;
+                    // Walk backwards so uses are marked before the nodes that
+                    // produce them, as in the upstream pass.
+                    for (const node of Array.from(block.nodes()).reverse()) {
+                        if (node.kind() === 'prim::Loop') {
+                            const marked = this.markLoop(node);
+                            anyMarked = anyMarked || marked;
+                        } else {
+                            for (const subBlock of node.blocks()) {
+                                const marked = this.mark(subBlock);
+                                anyMarked = anyMarked || marked;
+                            }
+                        }
+                        const marked = this.markIfLive(node);
+                        anyMarked = anyMarked || marked;
+                    }
+                    return anyMarked;
+                }
+                if (args.length === 1 && args[0] instanceof torch.Node) {
+                    const [node] = args;
+                    if (this._marked.has(node)) {
+                        return false;
+                    }
+                    this._marked.add(node);
+                    let curNode = node;
+                    while (curNode && curNode.owningBlock()) {
+                        this.mark(curNode);
+                        curNode = curNode.owningBlock().owningNode();
+                    }
+                    for (const input of node.inputs()) {
+                        if (!this._liveValues.has(input)) {
+                            this._liveValues.add(input);
+                        }
+                    }
+                    return true;
+                }
+                throw new python.Error('Not implemented.');
+            }
+            markIfLive(node) {
+                for (const output of node.outputs()) {
+                    if (this._liveValues.has(output)) {
+                        return this.mark(node);
+                    }
+                }
+                if (this._useAliasDb) {
+                    if (this.getOrCreateAliasDb().writesToAlias(node, this._liveValues)) {
+                        return this.mark(node);
+                    }
+                }
+                return false;
+            }
+            sweep(block, recurse) {
+                // Copy before reversing so the block's own node list is not mutated.
+                const nodes = Array.from(block.nodes()).reverse();
+                for (const node of nodes) {
+                    this.removeDeadBlockOutputs(node);
+                    this.removeDeadLoopOutputs(node);
+                    if (recurse) {
+                        for (const block of node.blocks()) {
+                            this.sweep(block, true);
+                        }
+                    }
+                    if (!this._marked.has(node) && !node.hasUses()) {
+                        node.destroy();
+                    }
+                }
+            }
+            hasUntrackedMutation(node) {
+                if (!this._useAliasDb) {
+                    if (node.kind() === 'prim::SetAttr') {
+                        return true;
+                    }
+                    const schema = node.maybeSchema();
+                    return schema && schema.is_mutable();
+                }
+                return this.getOrCreateAliasDb().writesToWildcard(node);
+            }
+            hasSideEffects(node) {
+                // Use has() so memoized 'false' results are honored too.
+                if (this._memo.has(node)) {
+                    return this._memo.get(node);
+                }
+                const has_side_effects = node.hasSideEffects() ||
+                    node.blocks().some((b) => b.nodes().some((n) => this.hasSideEffects(n))) ||
+                    this.hasUntrackedMutation(node);
+                this._memo.set(node, has_side_effects);
+                return has_side_effects;
+            }
+            removeDeadBlockOutputs(node) {
+                if (node.kind() !== 'prim::If' && node.kind() !== 'prim::GradOf') {
+                    return;
+                }
+                for (let i = node.outputs().length - 1; i >= 0; i--) {
+                    if (!node.outputs()[i].hasUses()) {
+                        node.eraseOutput(i);
+                        for (const b of node.blocks()) {
+                            b.eraseOutput(i);
+                        }
+                    }
+                }
+            }
+            removeDeadLoopOutputs(/* node */) {
+            }
+            getOrCreateAliasDb() {
+                if (!this._aliasDb) {
+                    this._aliasDb = new torch._C.AliasDb(this._graph);
+                }
+                return this._aliasDb;
+            }
+        });
+        this.registerFunction('torch._C.EliminateDeadCode', (...args) => {
+            if (args.length === 1 && args[0] instanceof torch.Graph) {
+                const [graph] = args;
+                const sideEffectPolicy = 'DONT_DELETE_NODES_WITH_SIDE_EFFECTS';
+                const worker = new torch._C.DeadCodeEliminator(graph, sideEffectPolicy);
+                worker.run(graph.block(), /*recurse=*/true);
+            } else {
+                throw new python.Error('Not implemented.');
+            }
+        });
+        this.registerFunction('torch._C.removeTupleNodes', () => {
+        });
+        this.registerFunction('torch._C.LowerSimpleTuples', (...args) => {
+            if (args.length === 1 && args[0] instanceof torch.Graph) {
+                const [graph] = args;
+                torch._C.LowerSimpleTuples(graph.block());
+                torch._C.EliminateDeadCode(graph);
+            } else if (args.length === 1 && args[0] instanceof torch.Block) {
+                const [block] = args;
+                for (const n of block.nodes()) {
+                    torch._C.removeTupleNodes(n, /*must_remove_tuples*/ false);
+                    for (const b of n.blocks()) {
+                        torch._C.LowerSimpleTuples(b);
+                    }
+                }
+            } else {
+                throw new python.Error('Not implemented.');
+            }
         });
         this.registerType('torch._C.ConstantPropagator', class {
             constructor(graph, aliasing_types, ignore_custom_classes) {
+                this._made_change = false;
                 this._graph = graph;
                 this._aliasing_types = aliasing_types;
                 this._ignore_custom_classes = ignore_custom_classes;
@@ -5196,6 +5428,8 @@ python.Execution = class {
                 this.ConstantPropagation(this._graph.block());
                 return this._made_change;
             }
+            propagateNode(/* n */) {
+            }
             removeExtraIfOutputs(n) {
                 torch._C.TORCH_CHECK(n.kind() === 'prim::If');
                 const [true_block, false_block] = n.blocks();
@@ -5217,11 +5451,40 @@ python.Execution = class {
                     }
                     i++;
                 }
-                this._made_change |= initial_outputs !== true_block.outputs().length;
+                this._made_change = this._made_change || (initial_outputs !== true_block.outputs().length);
                 guard.dispose();
             }
-            supportedNode() {
-                return false; // not implemented.
+            noMutableValues(values) {
+                return values.every((v) => !torch._C.AliasDb.isMutableType(v));
+            }
+            getOrCreateAliasDb() {
+                if (!this._aliasDb) {
+                    this._aliasDb = new torch._C.AliasDb(this._graph);
+                }
+                return this._aliasDb;
+            }
+            supportedNode(n) {
+                torch._C.skip_list = torch._C.skip_list || new Set([
+                    'prim::If',
+                    'prim::Loop',
+                    'prim::Closure',
+                    'prim::Constant',
+                    'prim::AutogradZero',
+                    'prim::Uninitialized',
+                    'prim::Guard',
+                    'prim::profile',
+                    'prim::profile_ivalue',
+                    'prim::unchecked_unwrap_optional',
+                    'prim::awaitable',
+                    'aten::dequantize'
+                ]);
+                let no_mutation = false;
+                if (this._aliasing_types) {
+                    no_mutation = !this.getOrCreateAliasDb().hasWriters(n);
+                } else {
+                    no_mutation = this.noMutableValues(n.inputs()) && this.noMutableValues(n.outputs());
+                }
+                return no_mutation && !n.kind().startsWith('onnx::') &&
+                    !torch._C.skip_list.has(n.kind()) && !n.isNondeterministic() &&
+                    !n.hasSideEffects() && n.blocks().length === 0;
             }
             ConstantPropagation(...args) {
                 if (args[0] instanceof torch.Graph) {
@@ -5274,8 +5537,141 @@ python.Execution = class {
                 }
                 return made_change;
             });
+        this.registerType('torch._C.MutableTypePtrHelper', class {
+            constructor(mutable_type_cache) {
+                this._mutable_type_cache = mutable_type_cache;
+            }
+            mapTypeToAliasTypeSet(type) {
+                if (this._mutable_type_cache) {
+                    const result = this.mapTypeToBorrowedAliasTypeSet(type);
+                    if (result) {
+                        return result;
+                    }
+                }
+                return this.mapTypeToAliasTypeSetImpl(type);
+            }
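+            // mapTypeToBorrowedAliasTypeSet is called above and by
+            // torch._C.isMutableTypeImpl below, but was not defined. A minimal
+            // sketch of the upstream cache lookup: return the memoized set,
+            // computing and caching it on a miss.
+            mapTypeToBorrowedAliasTypeSet(type) {
+                if (this._mutable_type_cache.has(type)) {
+                    return this._mutable_type_cache.get(type);
+                }
+                const mutable_types = this.mapTypeToAliasTypeSetImpl(type);
+                if (mutable_types) {
+                    this._mutable_type_cache.set(type, mutable_types);
+                }
+                return mutable_types;
+            }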
+            mapTypeToAliasTypeSetImpl(type) {
+                if (type instanceof torch.ListType ||
+                    type instanceof torch.DictType ||
+                    type instanceof torch.ClassType ||
+                    type instanceof torch.TensorType) {
+                    // An alias type set is a plain array of types here; there is no
+                    // registered torch._C.AliasTypeSet constructor to call.
+                    return [torch._C.unshapedType(type)];
+                }
+                if (type instanceof torch.UnionType) {
+                    const mutable_types = [];
+                    for (const inner of type.containedTypes()) {
+                        const maybe_inner_types = this.mapTypeToAliasTypeSet(inner);
+                        if (maybe_inner_types) {
+                            throw new python.Error('Not implemented.');
+                            // mutable_types.insert(mutable_types.end(), (*maybe_inner_types).begin(), (*maybe_inner_types).end());
+                        }
+                    }
+                    if (mutable_types.length === 0) {
+                        return null;
+                    }
+                    return mutable_types;
+                }
+                if (type instanceof torch.OptionalType) {
+                    const inner = type.getElementType();
+                    return this.mapTypeToAliasTypeSet(inner);
+                }
+                if (type instanceof torch.AnyType) {
+                    return [type];
+                }
+                if (type instanceof torch.FutureType) {
+                    throw new python.Error('Not implemented.');
+                    /*
+                    if (auto maybe_mut_types = mapTypeToAliasTypeSet(type->castRaw<FutureType>()->getElementType())) {
+                        return {AliasTypeSet{FutureType::create(*toSingleType(*maybe_mut_types))}};
+                    }
+                    return std::nullopt;
+                    */
+                }
+                if (type instanceof torch.AwaitType) {
+                    throw new python.Error('Not implemented.');
+                    /*
+                    if (auto maybe_mut_types = mapTypeToAliasTypeSet(type->castRaw<AwaitType>()->getElementType())) {
+                        return {AliasTypeSet{AwaitType::create(*toSingleType(*maybe_mut_types))}};
+                    }
+                    return std::nullopt;
+                    */
+                }
+                if (type instanceof torch.TupleType) {
+                    throw new python.Error('Not implemented.');
+                    /*
+                    const mutable_types = [];
+                    for (const inner of type.elements()) {
+                        const maybe_inner_types = this.mapTypeToAliasTypeSet(inner);
+                        if (maybe_inner_types) {
+                            mutable_types.push(...maybe_inner_types);
+                        }
+                    }
+                    if (mutable_types.length === 0) {
+                        return null;
+                    }
+                    return [torch.TupleType.create(mutable_types)];
+                    */
+                }
+                return null;
+            }
+        });
+        this.registerFunction('torch._C.isMutableTypeImpl', (type, mutable_type_cache) => {
+            if (type instanceof torch.TensorType || type instanceof torch.ListType ||
+                type instanceof torch.ClassType || type instanceof torch.DictType) {
+                return true;
+            }
+            const helper = new torch._C.MutableTypePtrHelper(mutable_type_cache);
+            if (mutable_type_cache) {
+                return helper.mapTypeToBorrowedAliasTypeSet(type) !== null;
+            }
+            return helper.mapTypeToAliasTypeSet(type) !== null;
+        });
         this.registerType('torch._C.AliasDb', class {
-
+            constructor(graph) {
+                // graph kept for parity with the native AliasDb(graph) constructor;
+                // the wildcard index is initialized so writesToWildcard() can iterate it.
+                this._graph = graph;
+                this._writeIndex = new Map();
+                this._wildcardIndex = new Map();
+            }
+            static isMutableType(...args) {
+                if (args[0] instanceof torch.Type) {
+                    const [type] = args;
+                    return torch._C.isMutableTypeImpl(type, null);
+                }
+                if (args[0] instanceof torch.Value) {
+                    const [value] = args;
+                    return torch._C.AliasDb.isMutableType(value.type());
+                }
+                throw new python.Error('Not implemented.');
+            }
+            writesToAlias(/* n, vs */) {
+                /*
+                const writtenTo = this.getWrites(n);
+                if (writtenTo.length === 0) {
+                    return false;
+                }
+                MemoryLocations locs;
+                for (const v of vs) {
+                    auto it = elementMap_.find(v);
+                    if (it != elementMap_.end()) {
+                        const auto& vlocs = memoryDAG_->getMemoryLocations(it->second);
+                        if (writtenTo.intersects(vlocs)) {
+                            return true;
+                        }
+                    }
+                }
+                */
+                return false;
+            }
+            writesToWildcard(n) {
+                if (!this._writeIndex.has(n)) {
+                    return false;
+                }
+                const writes = this._writeIndex.get(n);
+                for (const pr of this._wildcardIndex) {
+                    const [, wildcardElement] = pr;
+                    if (writes.test(wildcardElement.index)) {
+                        return true;
+                    }
+                }
+                return false;
+            }
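+            // hasWriters is used by ConstantPropagator.supportedNode() but was not
+            // defined. Stub sketch: like writesToAlias() above, optimistically report
+            // no writers until the memory DAG is ported.
+            hasWriters(/* n */) {
+                return false;
+            }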
         });
         this.registerFunction('torch._C.TORCH_INTERNAL_ASSERT', (cond) => {
             if (!cond) {
@@ -7381,6 +7777,7 @@ python.Execution = class {
         this.registerType('torch.ClassType', class extends torch.Type {
             constructor(qualified_name, cu, is_module) {
                 super('ClassType', typeof qualified_name === 'string' ? qualified_name : qualified_name.qualifiedName());
+                this._name = typeof qualified_name === 'string' ? new torch._C.QualifiedName(qualified_name) : qualified_name;
                 this._is_module = is_module;
                 this._attributes = [];
                 this._attributeTypes = [];
@@ -7395,7 +7792,7 @@ python.Execution = class {
                 return this.annotation_str;
             }
             name() {
-                return this._qualified_name.split('.').pop();
+                return this._name;
             }
             is_module() {
                 return this._is_module;
@@ -7515,7 +7912,7 @@ python.Execution = class {
             bool float_found = false;
             bool complex_found = false;
             bool nonetype_found = false;
-            auto update_is_opt_flags = [&](const TypePtr& t) {
+            const update_is_opt_flags = [&](const TypePtr& t) {
                 if (t == IntType::get()) {
                     int_found = true;
                 } else if (t == FloatType::get()) {
@@ -7535,7 +7932,7 @@ python.Execution = class {
                 return OptionalType::create(NumberType::get());
             }
             if (union_type->containedTypes().size() == 2) {
-                auto not_none = union_type->containedTypes()[0] != NoneType::get()
+                const not_none = union_type->containedTypes()[0] != NoneType::get()
                     ? union_type->containedTypes()[0]
                     : union_type->containedTypes()[1];
                 return OptionalType::create(not_none);
@@ -7582,7 +7979,7 @@ python.Execution = class {
                 // return;
             }
             /*
-            auto get_supertype = [](const TypePtr& t1, const TypePtr& t2) -> std::optional<TypePtr> {
+            const get_supertype = [](const TypePtr& t1, const TypePtr& t2) -> std::optional<TypePtr> {
                 // We don't want nested Optionals. Also, prematurely unifying to
                 // `Optional` could prevent us from coalescing other types
                 if ((t1->isSubtypeOf(*NoneType::get()) && !t2->isSubtypeOf(*NoneType::get()))
@@ -8811,6 +9208,71 @@ python.Execution = class {
             }
             return list.join('');
         }
+            aliasAnalysis() {
+                return this._alias_kind || 'CONSERVATIVE';
+            }
+            setAliasAnalysis(v) {
+                this._alias_kind = v;
+            }
+            hasAnyAliasInfo() {
+                // Truthy check: alias_info may be null or undefined when a schema is
+                // parsed without annotations.
+                for (const arg of this.arguments) {
+                    if (arg.alias_info) {
+                        return true;
+                    }
+                }
+                for (const ret of this.returns) {
+                    if (ret.alias_info) {
+                        return true;
+                    }
+                }
+                return false;
+            }
         });
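+        // SchemaInfo is a minimal port of torch::utils::SchemaInfo;
+        // Node.isNondeterministic() below wraps an operator schema in it when
+        // constant propagation asks whether a node may be folded.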
+        this.registerType('torch._C.SchemaInfo', class {
+            constructor(schema) {
+                this._schema = schema;
+                this._alias_maps_current = false;
+                this._has_init = false;
+            }
+            is_nondeterministic() {
+                // TODO: upstream special-cases 'aten::dropout' (overload '') as
+                // deterministic when its 'train' argument is statically false.
+                torch._C.nondeterministic_op_strings = torch._C.nondeterministic_op_strings || new Set([
+                    'aten::dropout(Tensor input, float p, bool train) -> Tensor',
+                    'aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)',
+                    'aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor',
+                    'aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor',
+                    'aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor',
+                    'aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor',
+                    'aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)',
+                    'aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor',
+                    'aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor',
+                    'aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor',
+                    'aten::poisson(Tensor self, Generator? generator) -> Tensor',
+                    'aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor',
+                    'aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor',
+                    'aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor',
+                    'aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor',
+                    'aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor',
+                    'aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor',
+                    'aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor',
+                    'aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor',
+                    'aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor',
+                    'aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor',
+                    'aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor',
+                    'aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor'
+                ]);
+                if (torch._C.nondeterministic_op_strings.has(this._schema.__str__())) {
+                    return true;
+                }
+                /*
+                const auto& op = c10::Dispatcher::singleton().findOp(
+                    c10::OperatorName(schema_.name(), schema_.overload_name()));
+                return op && op->hasTag(at::Tag::nondeterministic_seeded);
+                */
+                return false;
+            }
+        });
         this.registerFunction('torch._C.string_to_type_lut', () => {
             if (!torch._C.string_to_type_lut.basePythonTypes) {
@@ -9389,7 +9851,7 @@ python.Execution = class {
                     out.write('with ');
                     out.write(fg.kind());
                     out.write(`_${i} = `);
-                    out.write(fg.g('subsgraph'));
+                    out.write(fg.g('Subgraph'));
                 }
                 return out;
             }
@@ -9534,6 +9996,14 @@ python.Execution = class {
                 }
                 return this._op ? this._op.schema() : null;
             }
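+            // Sketch of the upstream Node::hasNamedInput helper; isNondeterministic()
+            // below uses it to detect a statically-known 'train' argument.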
+            hasNamedInput(name) {
+                for (const argument of this.schema().arguments) {
+                    if (argument.name === name) {
+                        return true;
+                    }
+                }
+                return false;
+            }
             matches(schema) {
                 if (torch._C.isBlockListedSchema(schema)) {
                     return false;
@@ -9566,6 +10036,13 @@ python.Execution = class {
                 }
                 return true;
             }
+            maybeSchema() {
+                const op = this.maybeOperator();
+                if (op) {
+                    return op.schema();
+                }
+                return null;
+            }
             maybeOperator() {
                 if (!this._op) {
                     const candidates = torch._C.getAllOperatorsFor(this.kind());
@@ -9588,6 +10065,77 @@ python.Execution = class {
             getOperation() {
                 return this.getOperator().getOperation(this);
             }
+            isNondeterministic() {
+                const schema = this.maybeSchema();
+                if (!this.kind().startsWith('aten::')) {
+                    return false;
+                }
+                if (!schema) {
+                    return false;
+                }
+                const schema_info = new torch._C.SchemaInfo(schema);
+                if (this.hasNamedInput('train')) {
+                    throw new python.Error('Not implemented.');
+                    // const value = constant_as(this.namedInput("train"));
+                    // if (value) {
+                    //     schema_info.addArgumentValue('train', value);
+                    // }
+                }
+                return schema_info.is_nondeterministic();
+            }
+            hasSideEffects() {
+                switch (this._kind) {
+                    case 'prim::PythonOp':
+                    case 'prim::IgnoredPythonOp':
+                    case 'prim::Print':
+                    case 'prim::RaiseException':
+                    case 'aten::warn':
+                    case 'aten::save':
+                    case 'aten::manual_seed':
+                    case 'prim::AddStatValue':
+                    case 'prim::TimePoint':
+                    case 'prim::CallFunction':
+                    case 'prim::CallMethod':
+                    case 'prim::BailoutTemplate':
+                    case 'prim::BailOut':
+                    case 'prim::rpc_async':
+                    case 'prim::rpc_sync':
+                    case 'prim::rpc_remote':
+                    case 'aten::wait':
+                    case 'cuda::set_stream':
+                    case 'cuda::_set_device':
+                    case 'cuda::_current_device':
+                    case 'cuda::synchronize':
+                    case 'prim::Enter':
+                    case 'prim::Exit':
+                        return true;
+                    default:
+                        break;
+                }
+                const op = this.maybeOperator();
+                if (!op) {
+                    torch._C.TORCH_INTERNAL_ASSERT(this._kind.startsWith('prim::'));
+                    return false;
+                }
+                if (this._kind.startsWith('prim::') || this._kind.startsWith('aten::') || this._kind.startsWith('cuda::')) {
+                    torch._C.TORCH_INTERNAL_ASSERT(
+                        op.aliasAnalysisKind() === 'INTERNAL_SPECIAL_CASE' ||
+                        op.aliasAnalysisKind() === 'FROM_SCHEMA' ||
+                        op.aliasAnalysisKind() === 'CONSERVATIVE');
+                }
+                switch (op.aliasAnalysisKind()) {
+                    case 'PURE_FUNCTION':
+                    case 'FROM_SCHEMA':
+                    case 'INTERNAL_SPECIAL_CASE':
+                        return false;
+                    case 'CONSERVATIVE':
+                        return true;
+                    default:
+                        break;
+                }
+                torch._C.TORCH_INTERNAL_ASSERT(false);
+                return false;
+            }
             inputs() {
                 return this._inputs;
             }
@@ -9607,6 +10155,14 @@ python.Execution = class {
             }
                 return this._outputs[0];
             }
+            hasUses() {
+                for (const o of this.outputs()) {
+                    if (o.uses().length > 0) {
+                        return true;
+                    }
+                }
+                return false;
+            }
             blocks() {
                 return this._blocks;
             }
@@ -9679,9 +10235,7 @@ python.Execution = class {
                 return this;
             }
             dropInput(i) {
-                if (i >= this._inputs.length) {
-                    throw new python.Error('Input index out of range.');
-                }
+                torch._C.AT_ASSERT(i < this._inputs.length);
                 const input_node = this._inputs[i];
                 const use_it = this.findUseForInput(i);
                 input_node._uses.splice(use_it.offset, 1);
@@ -9689,10 +10243,15 @@ python.Execution = class {
                 return input_node;
             }
             eraseOutput(i) {
+                torch._C.AT_ASSERT(i < this._outputs.length);
+                // torch._C.AT_ASSERT(this._outputs[i].uses().length === 0);
                 this._op = null;
                 const v = this._outputs[i];
                 this._outputs.splice(i, 1);
                 this.owningGraph().freeValue(v);
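+                // Assumption: netron Values record their output position in _offset;
+                // keep the remaining outputs consistent after the splice.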
+                for (let j = i; j < this._outputs.length; j++) {
+                    this._outputs[j]._offset--;
+                }
             }
             eraseBlock(i) {
                 this._op = null;
@@ -9850,7 +10409,7 @@ python.Execution = class {
                 out.write(' = ');
                 if (this.kind() === 'prim::PythonOp') {
                     throw new python.Error('Not implemented.');
-                } else if (this.hasAttribute('subgraph') && groups) {
+                } else if (this.hasAttribute('Subgraph') && groups) {
                     throw new python.Error('Not implemented.');
                 } else {
                     out.write(this.kind());
@@ -9900,6 +10459,9 @@ python.Execution = class {
             uses() {
                 return this._uses;
             }
+            hasUses() {
+                return this._uses.length > 0;
+            }
             hasDebugName() {
                 return this._unique_name;
             }
@@ -10339,8 +10901,9 @@ python.Execution = class {
                if (!src) {
                    return;
                }
-               const program = this._cu.execution.parse(src.filename(), src.text_str(), null);
-               for (const stmt of program.body) {
+               const p = this._cu.execution.parse(src.filename(), src.text_str(), null);
+               this.parsePossibleVersionNumber(p);
+               for (const stmt of p.body) {
                    if (stmt instanceof ast.ClassDef) {
                        const name = `${qualifier}.${stmt.name}`;
                        this._to_be_defined.set(name, stmt);
@@ -10350,6 +10913,26 @@ python.Execution = class {
                    }
                }
            }
+           parsePossibleVersionNumber(/* p */) {
+           }
+           parseImports(/* p */) {
+           }
+           LEGACY_import_methods(mod, src) {
+               const self = new torch._C.SimpleSelf(mod.type());
+               const prefix = mod.type().name();
+               const p = this._cu.execution.parse(src.filename(), src.text_str(), null);
+               this.parsePossibleVersionNumber(p);
+               this.parseImports(p);
+               const definitions = [];
+               const resolvers = [];
+               for (const def of p.body) {
+                   if (def instanceof ast.FunctionDef) {
+                       definitions.push(def);
+                       resolvers.push(this);
+                   }
+               }
+               this._cu.define(prefix, /*properties=*/[], /*propResolvers=*/[], definitions, resolvers, self);
+           }
            findFunction(name) {
                this.parseSourceIfNeeded(name.prefix());
                const key = name.qualifiedName();
@@ -10598,13 +11181,15 @@ python.Execution = class {
                const [data, size] = reader_->getRecord(module_def.torchscript_debug_arena().key());
                gen_ranges = std::make_shared<ConcreteSourceRangeUnpickler>(std::move(data), size);
            }
-           if (module_def.has_torchscript_arena()) {
-               const [data, size] =
-                   reader_->getRecord(module_def.torchscript_arena().key());
-               std::string data_str(static_cast<const char*>(data.get()), size);
-               const src = std::make_shared<Source>(std::string(static_cast<const char*>(data.get()), size), module_def.torchscript_arena().key(), 1, std::move(gen_ranges));
-               source_importer_.LEGACY_import_methods(module, src);
+           */
+           if (module_def.torchscript_arena) {
+               const filename = module_def.torchscript_arena.key;
+               const stream = this._reader.get_record(filename);
+               const data = stream.peek();
+               const src = new torch._C.Source(data, filename);
+               this._source_importer.LEGACY_import_methods(module, src);
            }
+           /*
            if (module_def.has_get_state_attribute_id()) {
                LEGACY_moduleSetState(module, LEGACY_pickled_ivalues_.at(module_def.get_state_attribute_id()));
            }
@@ -10676,8 +11261,9 @@ python.Execution = class {
                if (!reader.has_record(path)) {
                    return null;
                }
-               const data = reader.get_record(path);
-               return new torch._C.Source(data.peek(), path);
+               const stream = reader.get_record(path);
+               const data = stream.peek();
+               return new torch._C.Source(data, path);
            }
        });
        this.registerType('torch._C.WithInsertPoint', class {
@@ -10698,6 +11284,14 @@ python.Execution = class {
                this.next = next;
                this.value_table = new Map();
                this.type_table = new Map();
+               this.error_messages = new Map();
+           }
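+           // The message is recorded on the outermost (function-level) frame so it
+           // survives block-frame pops and can be raised when the name is loaded later.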
+           setVariableTypeError(name, msg) {
+               let runner = this;
+               while (runner.next) {
+                   runner = runner.next;
+               }
+               runner.error_messages.set(name, msg);
           }
           insertLoad(name, type) {
               const g = this.b.owningGraph();
@@ -12633,7 +13227,7 @@ python.Execution = class {
                return result;
            }
        });
-       this.registerType('torch.jit.to_ir', class {
+       this.registerType('torch._C.to_ir', class {
            constructor(def, _resolver, self, method) {
                this.method = method;
                this.graph = method.graph();
@@ -12645,6 +13239,7 @@ python.Execution = class {
                this.environment_stack = null;
                this._def_stack = [];
                this._temp_name_count = 0;
+               torch._C.AT_ASSERT(this.resolver);
                this.pushFrame(this.graph.block(), true);
                if (self && def && def.args.args.length === 0) {
                    throw new python.Error('Method must have a self argument.');
@@ -12652,7 +13247,7 @@ python.Execution = class {
                method.setSchema(this.emitDef(def, self, this.graph.block()));
                // torch._C.ReplaceOldOperatorsWithUpgraders(this.graph);
                torch._C.ConvertToSSA(this.graph);
-               // torch._C.CanonicalizeModifiedLoops(this.graph);
+               torch._C.CanonicalizeModifiedLoops(this.graph);
                torch._C.NormalizeOps(this.graph.block());
                torch._C.runCleanupPasses(this.graph);
            }
@@ -13492,7 +14087,7 @@ python.Execution = class {
            };
            // properties
            for (let i = 0; i < definitions.length; i++) {
-               const fn = this.define(prefix, definitions[i], defResolvers[i], self, function_table, shouldMangle, 'method', operator_set_version);
+               const fn = this.define(prefix, definitions[i], defResolvers[i], self, function_table, shouldMangle, 'Method', operator_set_version);
                record_function(fn);
            }
            for (const [name, fn] of function_table) {
@@ -13506,29 +14101,18 @@ python.Execution = class {
                return functions;
            } else if (args[1] instanceof ast.FunctionDef) {
                const [, def, resolver, self, function_table, shouldMangle, type, operator_set_version] = args;
-               let _resolver = resolver;
-               if (!self) {
-                   _resolver = new torch._C.FunctionResolver(resolver, function_table);
-               }
+               const _resolver = self ? resolver : new torch._C.FunctionResolver(resolver, function_table);
                const creator = (method) => {
-                   // let call_name = method.qualname().name();
-                   // if (self) {
-                   //     const atoms = method.qualname().atoms();
-                   //     // TORCH_INTERNAL_ASSERT(atoms.size() >= 2);
-                   //     call_name = `${atoms.at(atoms.size() - 2)}.${atoms.at(atoms.size() - 1)}`;
-                   // }
-                   // this.call(call_name, def.range());
-                   return new torch.jit.to_ir(def, _resolver, self, method);
+                   return new torch._C.to_ir(def, _resolver, self, method);
                };
-               const name = prefix ? new torch._C.QualifiedName(prefix, def.name) : new torch._C.QualifiedName(def.name);
+               let name = prefix ? new torch._C.QualifiedName(prefix, def.name) : new torch._C.QualifiedName(def.name);
+               if (shouldMangle && this.find_function(name)) {
+                   name = this.mangle(name);
+               }
                const graph = new torch.Graph();
                graph.set_op_version(operator_set_version);
                const fn = new torch._C.GraphFunction(name, graph, creator);
-               fn.__ast__ = def;
-               if (shouldMangle && this.find_function(name)) {
-                   // name = mangle(name);
-                   throw new python.Error('Not implemented.');
-               }
+               fn.__ast__ = def; // TODO: remove
                if (self) {
                    if (type === 'hook') {
                        self.getClassType().addForwardHook(fn);
@@ -13563,6 +14147,53 @@ python.Execution = class {
            erase_loads_stores.run(graph);
            torch._C.TransformExits(graph);
        });
+       this.registerFunction('torch._C.canonicalizeModifiedLoop', (/* n */) => {
+           /*
+           LoopView loop(n);
+           if (loop.loopType() != LoopView::ModifiedLoop) {
+               return;
+           }
+           const g = n.owningGraph();
+           WithInsertPoint node_insert(n);
+           const zero = g.insertConstant(0);
+           const one = g.insertConstant(1);
+           const max_trip_count = loop.maxTripCount();
+           let condition = g.insert(aten::gt, {max_trip_count, zero});
+           loop.replaceMaxTripCount(g.insertConstant(std::numeric_limits<int64_t>::max()));
+           const inp_condition = toIValue(loop.inputCond());
+           if (inp_condition == null || inp_condition.toBool() == false) {
+               condition = g.insert(aten::__and__, {condition, loop.inputCond()});
+           }
+           loop.replaceInputCondition(condition);
+           n.addOutput().setType(IntType::get());
+           WithInsertPoint loop_insert(loop.bodyBlock());
+           n.addInput(zero);
+           const new_iter = loop.bodyBlock().addInput().setType(IntType::get());
+           // unset unique name for jitter, its replacement does not have a name
+           loop.currentTripCount().setDebugName("").replaceAllUsesWith(new_iter);
+           const inc_iter = g.insert(aten::add, {new_iter, one});
+           loop.bodyBlock().registerOutput(inc_iter);
+           const less_than_max_trip = g.insert(aten::lt, {inc_iter, max_trip_count});
+           const loop_continue = loop.nextCond();
+           const new_condition = g.insert(aten::__and__, {less_than_max_trip, loop_continue});
+           loop.bodyBlock().eraseOutput(0);
+           loop.bodyBlock().insertOutput(0, new_condition);
+           */
+       });
+       this.registerFunction('torch._C.canonicalizeModifiedLoops', (block) => {
+           for (const n of block.nodes()) {
+               for (const b of n.blocks()) {
+                   torch._C.canonicalizeModifiedLoops(b);
+               }
+               if (n.kind() === 'prim::Loop') {
+                   torch._C.canonicalizeModifiedLoop(n);
+               }
+           }
+       });
+       this.registerFunction('torch._C.CanonicalizeModifiedLoops', (graph) => {
+           torch._C.canonicalizeModifiedLoops(graph.block());
+       });
        this.registerType('torch._C.MiniEnvironment', class {
            constructor(b, next) {
                this.next = next || null;
@@ -13596,6 +14227,9 @@ python.Execution = class {
        this.registerType('torch._C.TypeEnvironment', class extends torch._C.MiniEnvironment {
        });
        this.registerType('torch._C.ControlFlowLoadStores', class {
+           constructor() {
+               this.environment_stack = null;
+           }
            pushFrame(b) {
                this.environment_stack = new torch._C.TypeEnvironment(b, this.environment_stack);
            }
@@ -14025,9 +14659,9 @@ python.Execution = class {
            }
            /*
            torch._C.eraseListLiterals(to_clean);
-           torch._C.LowerSimpleTuples(to_clean);
-           torch._C.ConstantPropagationImmutableTypes(to_clean);
            */
+           torch._C.LowerSimpleTuples(to_clean);
+           // torch._C.ConstantPropagationImmutableTypes(to_clean);
            torch._C.ConstantPooling(to_clean);
            /*
            torch._C.CanonicalizeOutputs(to_clean);
diff --git a/source/pytorch.js b/source/pytorch.js
index 73260f9861..77caaed1e9 100644
--- a/source/pytorch.js
+++ b/source/pytorch.js
@@ -105,6 +105,7 @@ pytorch.Graph = class {
                }
            }
        }
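+       // Track nodes removed by the folding steps below so the final node walk can
+       // skip them (note the deleted.add(node) calls are currently commented out).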
+       const deleted = new Set();
        const param_node = graph.param_node();
        const self = param_node && param_node.outputs().length > 0 && param_node.outputs()[0].type() === module._c._type() ? param_node.outputs()[0] : null;
        if (self) {
@@ -138,6 +139,7 @@ pytorch.Graph = class {
                for (const output of node.outputs()) {
                    delattr(output);
                }
+               // deleted.add(node);
                node.destroy();
            }
        }
@@ -152,6 +154,7 @@ pytorch.Graph = class {
                    output.identifier = output.debugName();
                    output.value = value;
                }
+               // deleted.add(node);
                node.destroy();
            }
        }
@@ -163,6 +166,7 @@ pytorch.Graph = class {
                    const output = node.outputs()[i];
                    output.value = value[i];
                }
+               // deleted.add(node);
                node.destroy();
            }
        }
@@ -172,6 +176,7 @@ pytorch.Graph = class {
                node.inputs().every((value) => typeof value.value === 'number' && typeof value.value === 'string' && typeof value.value === 'boolean') &&
                node.outputs().every((value) => value.uses().every((use) => use.user.kind() !== 'prim::CallMethod'))) {
                node.outputs()[0].value = node.inputs().map((value) => value.value);
+               // deleted.add(node);
                node.destroy();
            }
        }
@@ -189,6 +194,9 @@ pytorch.Graph = class {
            this.outputs.push(new pytorch.Argument(identifier, [values.map(identifier)]));
        }
        for (const node of graph.nodes()) {
+           if (deleted.has(node)) {
+               continue;
+           }
            if (node === graph.param_node() || node === graph.return_node()) {
                continue;
            }
@@ -4051,6 +4059,7 @@ pytorch.Metadata = class {
                    if (type.category) {
                        schema.category = type.category;
                    }
+                   schema.setAliasAnalysis('FROM_SCHEMA');
                    const op = new torch._C.Operator(schema);
                    registry.registerOperator(op);
                    modules.add(type.name.split('::')[0]);