From 6c5fd060dea5242afe3ea904b0e9e69eada03cd1 Mon Sep 17 00:00:00 2001 From: nihuini Date: Mon, 25 Mar 2024 16:54:10 +0800 Subject: [PATCH] wip --- tools/pnnx/src/pass_level2.cpp | 500 +++++------------- .../pnnx/src/pass_level5/fuse_slice_copy.cpp | 7 +- 2 files changed, 122 insertions(+), 385 deletions(-) diff --git a/tools/pnnx/src/pass_level2.cpp b/tools/pnnx/src/pass_level2.cpp index c0a6bc90291..4ffadd64c67 100644 --- a/tools/pnnx/src/pass_level2.cpp +++ b/tools/pnnx/src/pass_level2.cpp @@ -1123,14 +1123,17 @@ static bool is_alias_op(const Operator* op) static void functionize(Graph& graph) { + // graph.save("0.param", "0.bin"); + // 1. create shadow view/slice/select/... for each consumer - // 2. tag operand alias for view/slice/select/... output - // 3. scan inplace op, collect affacted alias - // 4. replace with non-inplace version, create copy op + // 2. replace inplace op, append copy + // 3. tag operand alias for view/slice/select/... output + // 4. scan inplace op, collect affacted alias // 5. look for any op after the inplace op with alias input // 6. collect ops on the chain back to alias // 7. move chain after copy op // 8. update all alias uses after copy op, retag alias + // 9. clear all alias tag // 1. create shadow view/slice/select/... for each consumer { @@ -1181,16 +1184,56 @@ static void functionize(Graph& graph) } } - // 2. tag operand alias for view/slice/select/... output + // graph.save("1.param", "1.bin"); + + // 2. replace inplace op, append copy { for (size_t i = 0; i < graph.ops.size(); i++) { Operator* op = graph.ops[i]; + if (op->type == "aten::copy_") + continue; + bool is_inplace_op = op->type.size() > 2 && op->type[op->type.size() - 2] != '_' && op->type[op->type.size() - 1] == '_'; - if (!is_alias_op(op) && !is_inplace_op) - // if (!is_alias_op(op)) + if (!is_inplace_op) + continue; + + // replace with non-inplace version, create copy op + op->type = op->type.substr(0, op->type.size() - 1); + + // append aten::copy_ + { + Operand* in0 = op->inputs[0]; + Operand* out0 = op->outputs[0]; + + Operator* op_copy = graph.new_operator_after("aten::copy_", op->name + "_copy", op); + Operand* copy_out = graph.new_operand(op->name + "_copy_out"); + + copy_out->type = out0->type; + copy_out->shape = out0->shape; + + op_copy->inputs.push_back(in0); + op_copy->inputs.push_back(out0); + in0->consumers.push_back(op_copy); + out0->consumers.push_back(op_copy); + + op_copy->outputs.push_back(copy_out); + copy_out->producer = op_copy; + } + } + } + + // graph.save("2.param", "2.bin"); + + // 3. tag operand alias for view/slice/select/... output + { + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (!is_alias_op(op) && op->type != "aten::copy_") continue; Operand* in = op->inputs[0]; @@ -1209,74 +1252,78 @@ static void functionize(Graph& graph) { x->params["__alias__"] = alias_index; } + + // fprintf(stderr, "operand %s is alias of %s\n", op->outputs[0]->name.c_str(), graph.operands[alias_index]->name.c_str()); } } - // 3. scan inplace op, collect affacted alias + // graph.save("3.param", "3.bin"); + + // 4. scan inplace copy op, collect affacted alias { for (size_t i = 0; i < graph.ops.size(); i++) { Operator* op = graph.ops[i]; - bool is_inplace_op = op->type.size() > 2 && op->type[op->type.size() - 2] != '_' && op->type[op->type.size() - 1] == '_'; - - if (!is_inplace_op) + if (op->type != "aten::copy_") continue; - // inplace op output always alias with the input - const int alias_index = op->outputs[0]->params.at("__alias__").i; - - // 4. replace with non-inplace version, create copy op - op->type = op->type.substr(0, op->type.size() - 1); + op->type = "aten::copy"; - { - Operator* op_copy = graph.new_operator_after("aten::copy", op->name + "_copy", op); - Operand* copy_out = graph.new_operand(op->name + "_copy_out"); + Operand* out0 = op->outputs[0]; - copy_out->type = in0->type; - copy_out->shape = in0->shape; + // inplace op output always alias with the input + const int alias_index = out0->params.at("__alias__").i; + Operand* alias_in0 = graph.operands[alias_index]; - op_copy->inputs.push_back(op->inputs[0]); - op_copy->inputs.push_back(op->outputs[0]); - op->inputs[0]->consumers.push_back(op_copy); - op->outputs[0]->consumers.push_back(op_copy); + // fprintf(stderr, "\n---> %s for %s\n", op->name.c_str(), alias_in0->name.c_str()); - op_copy->outputs.push_back(copy_out); - copy_out->producer = op_copy; - } + size_t i_advanced = 0; // 5. look for any op after the inplace op with alias input - for (size_t j = i + 2; j < graph.ops.size(); j++) + for (size_t j = i + 1; j < graph.ops.size(); j++) { Operator* op1 = graph.ops[j]; bool affacted = false; - Operator* op10 = 0; for (Operand* x : op1->inputs) { + if (x == alias_in0) + { + affacted = true; + break; + } + if (x->params.find("__alias__") == x->params.end()) continue; int alias_index_1 = x->params.at("__alias__").i; if (alias_index_1 == alias_index) { - op10 = x->producer; affacted = true; break; } } + // fprintf(stderr, "op %s affacted %d\n", op1->name.c_str(), affacted); + if (!affacted) continue; // 6. collect ops on the chain back to alias - std::vector chainsx_op_indexes; + std::set chainsx_op_indexes; { - int op10_index = std::find(graph.ops.begin(), graph.ops.end(), op10) - graph.ops.begin(); - chainsx_op_indexes.push_back(op10_index); + size_t op1_index = std::find(graph.ops.begin(), graph.ops.end(), op1) - graph.ops.begin(); + + if (op1_index < i - i_advanced) + { + chainsx_op_indexes.insert(op1_index); + // fprintf(stderr, "affacted op %s for %s\n", op1->name.c_str(), graph.operands[alias_index]->name.c_str()); + } + while (1) { - Operand* x = op10->inputs[0]; + Operand* x = op1->inputs[0]; if (x->params.find("__alias__") == x->params.end()) break; @@ -1284,391 +1331,82 @@ static void functionize(Graph& graph) if (alias_index_1 != alias_index) break; - op10 = x->producer; - int op10_index = std::find(graph.ops.begin(), graph.ops.end(), op10) - graph.ops.begin(); - chainsx_op_indexes.push_back(op10_index); + op1 = x->producer; + size_t op1_index = std::find(graph.ops.begin(), graph.ops.end(), op1) - graph.ops.begin(); + + if (op1_index < i - i_advanced) + { + chainsx_op_indexes.insert(op1_index); + // fprintf(stderr, "affacted op %s for %s chained\n", op1->name.c_str(), graph.operands[alias_index]->name.c_str()); + } } } // 7. move chain after copy op - for (size_t k = 0; k < chainsx_op_indexes.size(); k++) { - int doi = chainsx_op_indexes[k]; - - // fprintf(stderr, "move %s after %s\n", graph.ops[doi]->name.c_str(), graph.ops[i + 1 - k]->name.c_str()); - - for (int l = doi; l <= i + 1 - k; l++) + int k = 0; + for (size_t doi : chainsx_op_indexes) { - std::swap(graph.ops[l], graph.ops[l+1]); - } - } - - // 8. update all alias uses after copy op, retag alias - for (size_t k = i + 2; k < graph.ops.size(); k++) - { - Operator* op2 = graph.ops[k]; + doi -= k; + // fprintf(stderr, "---> move %s after %s\n", graph.ops[doi]->name.c_str(), graph.ops[i - i_advanced]->name.c_str()); - bool use_in0 = false; - for (size_t l = 0; l < op2->inputs.size(); l++) - { - if (op2->inputs[l] == in0) + for (size_t l = doi; l < i - i_advanced; l++) { - op2->inputs[l] = out0; - use_in0 = true; + std::swap(graph.ops[l], graph.ops[l+1]); } - } - if (use_in0) - { - in0->remove_consumer(op2); - out0->consumers.push_back(op2); + k += 1; } - } - } - } - } - - while (1) - { - bool matched = false; - - for (int i = (int)graph.ops.size() - 1; i >= 0; i--) - { - Operator* op = graph.ops[i]; - - if (op->type != "aten::slice" && op->type != "aten::select") - continue; - - Operand* out0 = op->outputs[0]; - - if (out0->consumers.size() == 1) - continue; - - matched = true; - - // slice/select output has multiple consumers - // create one slice/select for each consumer - for (size_t j = 1; j < out0->consumers.size(); j++) - { - Operator* op1 = out0->consumers[j]; - Operator* op_shadow = graph.new_operator_before(op->type, op->name + "_pnnxshadow_" + std::to_string(j), op1); - - Operand* shadow_out = graph.new_operand(op_shadow->name + "_out"); - - op_shadow->inputs = op->inputs; - op_shadow->params = op->params; - op_shadow->outputs.push_back(shadow_out); - - for (Operand* x : op->inputs) - { - x->consumers.push_back(op_shadow); + i_advanced += chainsx_op_indexes.size(); } - shadow_out->producer = op_shadow; - shadow_out->type = out0->type; - shadow_out->shape = out0->shape; - shadow_out->params = out0->params; - - shadow_out->consumers.push_back(op1); - - for (size_t k = 0; k < op1->inputs.size(); k++) - { - if (op1->inputs[k] == out0) - op1->inputs[k] = shadow_out; - } - } - - out0->consumers.resize(1); - } - - if (!matched) - break; - } - - // graph.save("debug01.param", "debug01.bin"); - - while (1) - { - bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) - { - Operator* op = graph.ops[i]; - - bool is_inplace_op = op->type.size() > 2 && op->type[op->type.size() - 2] != '_' && op->type[op->type.size() - 1] == '_'; - if (!is_inplace_op) - continue; - - // replace inplace op with non-inplace version - op->type = op->type.substr(0, op->type.size() - 1); - - // find in0 from slice / select chain - Operand* in0 = op->inputs[0]; - while (in0->producer->type == "aten::slice" || in0->producer->type == "aten::select") - { - in0 = in0->producer->inputs[0]; - } - - if (op->type == "aten::copy") - continue; - - if (op->outputs[0]->consumers.size() != 0) - continue; - - // fprintf(stderr, "matched\n"); - - matched = true; - - // append copy for inplace op - Operator* op_copy = graph.new_operator_after("aten::copy", op->name + "_copy", op); - Operand* copy_out = graph.new_operand(op->name + "_copy_out"); - - copy_out->type = in0->type; - copy_out->shape = in0->shape; - - op_copy->inputs.push_back(op->inputs[0]); - op_copy->inputs.push_back(op->outputs[0]); - op->inputs[0]->consumers.push_back(op_copy); - op->outputs[0]->consumers.push_back(op_copy); - - op_copy->outputs.push_back(copy_out); - copy_out->producer = op_copy; - - break; - } - - if (!matched) - break; - } - - graph.save("debug01.param", "debug01.bin"); - - // while (1) - { - // bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) - { - Operator* op = graph.ops[i]; - - if (op->type != "aten::copy") - continue; - - // find in0 from slice / select chain - // Operand* in0 = op->inputs[0]; - // std::set op_slice_chain; - // while (in0->producer->type == "aten::slice" || in0->producer->type == "aten::select") - // { - // op_slice_chain.insert(in0->producer); - // in0 = in0->producer->inputs[0]; - // } - - std::unordered_set producer_ops; - collect_producer_ops(op->outputs[0], producer_ops); - - fprintf(stderr, "%s %s\n", op->type.c_str(), op->name.c_str()); - - // fprintf(stderr, "matched\n"); - - // matched = true; - -#if 0 - // decouple slice / select chain - for (size_t j = 0; j < op_slice_chain.size(); j++) - { - if (op_slice_chain[j]->outputs[0]->consumers.size() == 1) - continue; - - std::vector other_consumers; - for (Operator* x : op_slice_chain[j]->outputs[0]->consumers) - { - if (x == op || (j + 1 < op_slice_chain.size() && x == op_slice_chain[j + 1])) - continue; - - other_consumers.push_back(x); - } - - // append 0...j chain part for this consumer x - for (size_t k = 0; k < other_consumers.size(); k++) + // 8. update all alias uses after copy op, retag alias + out0->params.erase("__alias__"); + const int new_alias_index = std::find(graph.operands.begin(), graph.operands.end(), out0) - graph.operands.begin(); + for (size_t k = i - i_advanced + 1; k < graph.ops.size(); k++) { - Operator* other_consumer = other_consumers[k]; + Operator* op2 = graph.ops[k]; - Operator* op_cursor = op; - for (size_t jj = 0; jj <= j; jj++) + // bool use_in0 = false; + for (size_t l = 0; l < op2->inputs.size(); l++) { - Operator* op_slice_shadow = graph.new_operator_after(op_slice_chain[jj]->type, op_slice_chain[jj]->name + "_shadow_" + std::to_string(k) + "_" + std::to_string(jj), op_cursor); - - op_slice_shadow->params = op_slice_chain[jj]->params; - - Operand* slice_in = jj == 0 ? in0 : op_cursor->outputs[0]; - - op_slice_shadow->inputs.push_back(slice_in); - slice_in->consumers.push_back(op_slice_shadow); - if (op_slice_chain[jj]->inputs.size() > 1) + if (op2->inputs[l] == alias_in0) { - for (size_t kk = 1; kk < op_slice_chain[jj]->inputs.size(); kk++) - { - Operand* x = op_slice_chain[jj]->inputs[kk]; - op_slice_shadow->inputs.push_back(x); - x->consumers.push_back(op_slice_shadow); - } - } - - Operand* slice_out = graph.new_operand(op_slice_shadow->name + "_out"); - - slice_out->producer = op_slice_shadow; - - op_slice_shadow->outputs.push_back(slice_out); + // fprintf(stderr, "---> replace %s input %s to %s\n", op2->name.c_str(), op2->inputs[l]->name.c_str(), out0->name.c_str()); - op_cursor = op_slice_shadow; + op2->inputs[l] = out0; + alias_in0->remove_consumer(op2); + out0->consumers.push_back(op2); + } } - op_slice_chain[j]->outputs[0]->remove_consumer(other_consumer); - - Operand* slice_out = op_cursor->outputs[0]; - slice_out->consumers.push_back(other_consumer); - for (size_t kk = 0; kk < other_consumer->inputs.size(); kk++) + for (Operand* x : op2->outputs) { - if (other_consumer->inputs[kk] == op_slice_chain[j]->outputs[0]) + if (x->params.find("__alias__") != x->params.end() && x->params.at("__alias__").i == alias_index) { - other_consumer->inputs[kk] = slice_out; + x->params["__alias__"] = new_alias_index; } } - } - } -#endif - - // find all operators that depends on in0 before copy but not in this slice / select chain - { - // fprintf(stderr, "%s in0 %s\n", op->name.c_str(), in0->name.c_str()); - std::vector dependant_op_indexes; - - std::set dependants; - // dependants.insert(in0); - - int start = 0;//std::find(graph.ops.begin(), graph.ops.end(), in0->producer) - graph.ops.begin(); - int end = std::find(graph.ops.begin(), graph.ops.end(), op) - graph.ops.begin(); - for (int j = start; j < end; j++) - { - Operator* op1 = graph.ops[j]; - - if (op1->type == "prim::Constant") - continue; - - // if (op1->type == "aten::copy") - // { - // // skip previous inplace chain, restart scanning - // dependant_op_indexes.clear(); - // dependants.clear(); - // // dependants.insert(in0); - // continue; - // } - - if (producer_ops.find(op1) != producer_ops.end()) - continue; - - // bool is_dependant_op = false; - // for (Operand* x : op1->inputs) - // { - // if (dependants.find(x) != dependants.end()) - // { - // is_dependant_op = true; - // break; - // } - // } - // - // if (!is_dependant_op) - // continue; - - // if (std::find(op1->inputs.begin(), op1->inputs.end(), in0) != op1->inputs.end() && op1->type != "aten::slice" && op1->type != "aten::select") - // { - // // move from slice / select only - // continue; - // } - - dependant_op_indexes.push_back(j); - - for (Operand* x : op1->outputs) - { - dependants.insert(x); - } } - // move dependant ops after op_copy - // for (int j = (int)dependant_op_indexes.size() - 1; j >= 0; j--) - for (int j = 0; j < (int)dependant_op_indexes.size(); j++) - { - int doi = dependant_op_indexes[(int)dependant_op_indexes.size() - 1 - j]; - - fprintf(stderr, "move %s after %s\n", graph.ops[doi]->name.c_str(), graph.ops[end - j]->name.c_str()); - - for (int k = doi; k <= end - j; k++) - { - std::swap(graph.ops[k], graph.ops[k+1]); - } - } + // rewind to the updated copy operator + j -= chainsx_op_indexes.size(); } - - // break; } - - // if (!matched) - // break; } - graph.save("debug02.param", "debug02.bin"); + // graph.save("4.param", "4.bin"); - for (size_t i = 0; i < graph.ops.size(); i++) + // 9. clear all alias tag { - Operator* op = graph.ops[i]; - - if (op->type != "aten::copy") - continue; - - if (op->outputs[0]->consumers.size() != 0) - continue; - - // aten::slice 5 1 in0 .... a - // aten::slice 5 1 a .... b - // aten::copy 2 1 b in1 out - - // aten::select 3 1 in0 .... a - // aten::copy 2 1 a in1 out - - // find in0 from slice / select chain - Operand* in0 = op->inputs[0]; - while (in0->producer->type == "aten::slice" || in0->producer->type == "aten::select") + for (Operand* x : graph.operands) { - in0 = in0->producer->inputs[0]; - } - - // replace all the following uses of in0 with out - Operand* out0 = op->outputs[0]; - out0->shape = in0->shape; - for (size_t j = i; j < graph.ops.size(); j++) - { - Operator* op2 = graph.ops[j]; - - bool use_in0 = false; - for (size_t k = 0; k < op2->inputs.size(); k++) - { - if (op2->inputs[k] == in0) - { - op2->inputs[k] = out0; - use_in0 = true; - } - } - - if (use_in0) - { - in0->remove_consumer(op2); - out0->consumers.push_back(op2); - } + x->params.erase("__alias__"); } } - graph.save("debug03.param", "debug03.bin"); - } void pass_level2(Graph& g) diff --git a/tools/pnnx/src/pass_level5/fuse_slice_copy.cpp b/tools/pnnx/src/pass_level5/fuse_slice_copy.cpp index ab5dc6996e4..a106e6714cf 100644 --- a/tools/pnnx/src/pass_level5/fuse_slice_copy.cpp +++ b/tools/pnnx/src/pass_level5/fuse_slice_copy.cpp @@ -91,10 +91,10 @@ void fuse_slice_copy(Graph& graph) for (size_t j = 0; j < x->inputs.size(); j++) { if (x->inputs[j] == out) - x->inputs[j] = op->inputs[1]; + x->inputs[j] = op->inputs[0]; } - op->inputs[1]->consumers.push_back(x); + op->inputs[0]->consumers.push_back(x); } op->inputs[0]->remove_consumer(op); @@ -121,8 +121,7 @@ void fuse_slice_copy(Graph& graph) op->type = "Tensor.slice_copy"; - // insert clone before any slices - // Operator* op_clone = graph.new_operator_before("Tensor.clone", op->name + "_ncnnclone", top_sop); + // insert clone just after the producer Operator* op_clone = graph.new_operator_after("Tensor.clone", op->name + "_ncnnclone", top_sop->inputs[0]->producer); Operand* clone_out = graph.new_operand(op->name + "_ncnnclone_out");