diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index cebc5f42aa5..2fc9bf37757 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -303,6 +303,10 @@ set(pnnx_pass_level2_SRCS pass_level2/torch_fft_fftn.cpp pass_level2/nn_quantized_FloatFunctional.cpp + + pass_level2/nn_GRU.cpp + pass_level2/nn_LSTM.cpp + pass_level2/nn_RNN.cpp ) set(pnnx_pass_level3_SRCS diff --git a/tools/pnnx/src/load_onnx.cpp b/tools/pnnx/src/load_onnx.cpp index c7378468222..36624d916bd 100644 --- a/tools/pnnx/src/load_onnx.cpp +++ b/tools/pnnx/src/load_onnx.cpp @@ -545,85 +545,20 @@ int load_onnx(const std::string& onnxpath, Graph& pnnx_graph, onnx2pnnx::ModelStat oldstat = onnx2pnnx::get_model_stat(model); - fprintf(stderr, "%-34s", "inline_containers ... "); + double t0 = 0; + double t1 = 0; - double t0 = get_current_time(); + int inlined = 0; - onnx2pnnx::inline_containers(model); - - double t1 = get_current_time(); - - fprintf(stderr, "%8.2fms\n", t1 - t0); - - fprintf(stderr, "%-34s", "eliminate_noop ... "); - - t0 = get_current_time(); - - onnx2pnnx::eliminate_noop(model); - - t1 = get_current_time(); - - fprintf(stderr, "%8.2fms\n", t1 - t0); - - fprintf(stderr, "%-34s", "fold_constants ... "); - - t0 = get_current_time(); - - onnx2pnnx::fold_constants(model, input_shapes, input_types, input_shapes2, input_types2); - - t1 = get_current_time(); - - fprintf(stderr, "%8.2fms\n", t1 - t0); - - fprintf(stderr, "%-34s", "canonicalize ... "); - - t0 = get_current_time(); - - onnx2pnnx::canonicalize(model); - - t1 = get_current_time(); - - fprintf(stderr, "%8.2fms\n", t1 - t0); - - fprintf(stderr, "%-34s", "shape_inference ... "); - - t0 = get_current_time(); - - onnx2pnnx::shape_inference(model, input_shapes, input_types, input_shapes2, input_types2); - - t1 = get_current_time(); - - fprintf(stderr, "%8.2fms\n", t1 - t0); - - fprintf(stderr, "%-34s", "fold_constants_dynamic_shape ... 
"); - - t0 = get_current_time(); - - onnx2pnnx::fold_constants_dynamic_shape(model, input_shapes, input_types); - - t1 = get_current_time(); - - fprintf(stderr, "%8.2fms\n", t1 - t0); - - fprintf(stderr, "%-34s", "inline_if_graph ... "); - - t0 = get_current_time(); - - int inlined = onnx2pnnx::inline_if_graph(model); - - t1 = get_current_time(); - - fprintf(stderr, "%8.2fms\n", t1 - t0); - - while (inlined) + do { fprintf(stderr, "%-34s", "inline_containers ... "); - double t0 = get_current_time(); + t0 = get_current_time(); onnx2pnnx::inline_containers(model); - double t1 = get_current_time(); + t1 = get_current_time(); fprintf(stderr, "%8.2fms\n", t1 - t0); @@ -686,7 +621,8 @@ int load_onnx(const std::string& onnxpath, Graph& pnnx_graph, t1 = get_current_time(); fprintf(stderr, "%8.2fms\n", t1 - t0); - } + + } while (inlined); fprintf(stderr, "%-34s", "fuse_constant_as_attribute ... "); diff --git a/tools/pnnx/src/pass_level2/nn_GRU.cpp b/tools/pnnx/src/pass_level2/nn_GRU.cpp new file mode 100644 index 00000000000..b5066a96096 --- /dev/null +++ b/tools/pnnx/src/pass_level2/nn_GRU.cpp @@ -0,0 +1,599 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "pass_level2.h" +#include + +namespace pnnx { + +class nn_GRU_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Attribute W 0 1 W @data +pnnx.Attribute R 0 1 R @data +GRU gru 3 1 input W R out %*=%* +Squeeze sqz 1 1 out out1 axes=%axes +pnnx.Output output 1 0 out1 +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.GRU"; + } + + const char* name_str() const + { + return "gru"; + } + + bool match(const std::map& captured_params, const std::map& captured_attrs) const + { + if (captured_params.find("gru.hidden_size") == captured_params.end()) + return false; + + const int hidden_size = captured_params.at("gru.hidden_size").i; + + std::string direction = "forward"; + if (captured_params.find("gru.direction") != captured_params.end()) + { + direction = captured_params.at("gru.direction").s; + } + + if (direction != "forward" && direction != "bidirectional") + return false; + + const int num_directions = direction == "bidirectional" ? 
2 : 1; + + if (captured_params.find("gru.activations") != captured_params.end()) + { + const std::vector& acts = captured_params.at("gru.activations").as; + + if (num_directions == 1) + { + if (acts != std::vector{"Sigmoid", "Tanh"}) + return false; + } + else // if (num_directions == 2) + { + if (acts != std::vector{"Sigmoid", "Tanh", "Sigmoid", "Tanh"}) + return false; + } + } + + if (captured_params.find("axes") != captured_params.end()) + { + if (captured_params.at("axes").type == 2 && captured_params.at("axes").i != 1) + return false; + + if (captured_params.at("axes").type == 5 && captured_params.at("axes").ai != std::vector{1}) + return false; + } + + const auto& W = captured_attrs.at("W.data"); + const auto& R = captured_attrs.at("R.data"); + + if (W.shape.size() != 3 || W.shape[0] != num_directions || W.shape[1] != 3 * hidden_size) + return false; + + if (R.shape.size() != 3 || R.shape[0] != num_directions || R.shape[1] != 3 * hidden_size || R.shape[2] != hidden_size) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + std::string direction = "forward"; + if (captured_params.find("gru.direction") != captured_params.end()) + { + direction = captured_params.at("gru.direction").s; + } + + const auto& W = captured_attrs.at("W.data"); + const auto& R = captured_attrs.at("R.data"); + + bool batch_first = false; + if (captured_params.find("gru.layout") != captured_params.end()) + { + const int layout = captured_params.at("gru.layout").i; + batch_first = layout == 1; + } + + const int hidden_size = captured_params.at("gru.hidden_size").i; + + const int input_size = W.shape[2]; + + op->params["input_size"] = input_size; + op->params["hidden_size"] = hidden_size; + op->params["num_layers"] = 1; + op->params["bias"] = false; + op->params["batch_first"] = batch_first; + op->params["bidirectional"] = direction == "bidirectional" ? 
true : false; + + // split W R and reorder URN to RUN + auto W_data = W.get_float32_data(); + auto R_data = R.get_float32_data(); + + std::vector W2(3 * hidden_size * input_size); + { + const int weight_data_size_g = hidden_size * input_size; + + const float* uptr = (const float*)W_data.data(); + const float* rptr = (const float*)W_data.data() + weight_data_size_g; + const float* nptr = (const float*)W_data.data() + weight_data_size_g * 2; + + float* w_rptr = (float*)W2.data(); + float* w_uptr = (float*)W2.data() + weight_data_size_g; + float* w_nptr = (float*)W2.data() + weight_data_size_g * 2; + + memcpy(w_rptr, rptr, weight_data_size_g * sizeof(float)); + memcpy(w_uptr, uptr, weight_data_size_g * sizeof(float)); + memcpy(w_nptr, nptr, weight_data_size_g * sizeof(float)); + } + + std::vector R2(3 * hidden_size * hidden_size); + { + const int weight_data_size_g = hidden_size * hidden_size; + + const float* uptr = (const float*)R_data.data(); + const float* rptr = (const float*)R_data.data() + weight_data_size_g; + const float* nptr = (const float*)R_data.data() + weight_data_size_g * 2; + + float* w_rptr = (float*)R2.data(); + float* w_uptr = (float*)R2.data() + weight_data_size_g; + float* w_nptr = (float*)R2.data() + weight_data_size_g * 2; + + memcpy(w_rptr, rptr, weight_data_size_g * sizeof(float)); + memcpy(w_uptr, uptr, weight_data_size_g * sizeof(float)); + memcpy(w_nptr, nptr, weight_data_size_g * sizeof(float)); + } + + if (direction == "bidirectional") + { + op->attrs["weight_ih_l0"] = Attribute({3 * hidden_size, input_size}, W2); + op->attrs["weight_hh_l0"] = Attribute({3 * hidden_size, hidden_size}, R2); + + std::vector W2R(3 * hidden_size * input_size); + { + const int weight_data_size_g = hidden_size * input_size; + + const float* uptr = (const float*)W_data.data() + weight_data_size_g * 3; + const float* rptr = (const float*)W_data.data() + weight_data_size_g * 4; + const float* nptr = (const float*)W_data.data() + weight_data_size_g * 5; + + float* 
w_rptr = (float*)W2R.data(); + float* w_uptr = (float*)W2R.data() + weight_data_size_g; + float* w_nptr = (float*)W2R.data() + weight_data_size_g * 2; + + memcpy(w_rptr, rptr, weight_data_size_g * sizeof(float)); + memcpy(w_uptr, uptr, weight_data_size_g * sizeof(float)); + memcpy(w_nptr, nptr, weight_data_size_g * sizeof(float)); + } + + std::vector R2R(3 * hidden_size * hidden_size); + { + const int weight_data_size_g = hidden_size * hidden_size; + + const float* uptr = (const float*)R_data.data() + weight_data_size_g * 3; + const float* rptr = (const float*)R_data.data() + weight_data_size_g * 4; + const float* nptr = (const float*)R_data.data() + weight_data_size_g * 5; + + float* w_rptr = (float*)R2R.data(); + float* w_uptr = (float*)R2R.data() + weight_data_size_g; + float* w_nptr = (float*)R2R.data() + weight_data_size_g * 2; + + memcpy(w_rptr, rptr, weight_data_size_g * sizeof(float)); + memcpy(w_uptr, uptr, weight_data_size_g * sizeof(float)); + memcpy(w_nptr, nptr, weight_data_size_g * sizeof(float)); + } + + op->attrs["weight_ih_l0_reverse"] = Attribute({3 * hidden_size, input_size}, W2R); + op->attrs["weight_hh_l0_reverse"] = Attribute({3 * hidden_size, hidden_size}, R2R); + } + else + { + op->attrs["weight_ih_l0"] = Attribute({3 * hidden_size, input_size}, W2); + op->attrs["weight_hh_l0"] = Attribute({3 * hidden_size, hidden_size}, R2); + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_GRU_onnx, 10) + +class nn_GRU_onnx_B : public nn_GRU_onnx +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input_0 0 1 input +pnnx.Attribute W 0 1 W @data +pnnx.Attribute R 0 1 R @data +pnnx.Attribute B 0 1 B @data +GRU gru 4 1 input W R B out %*=%* +Squeeze sqz 1 1 out out1 axes=%axes +pnnx.Output output 1 0 out1 +)PNNXIR"; + } + + bool match(const std::map& captured_params, const std::map& captured_attrs) const + { + if (!nn_GRU_onnx::match(captured_params, captured_attrs)) + return false; + + const int 
hidden_size = captured_params.at("gru.hidden_size").i; + + std::string direction = "forward"; + if (captured_params.find("gru.direction") != captured_params.end()) + { + direction = captured_params.at("gru.direction").s; + } + + const int num_directions = direction == "bidirectional" ? 2 : 1; + + const auto& B = captured_attrs.at("B.data"); + + if (B.shape.size() != 2 || B.shape[0] != num_directions || B.shape[1] != 6 * hidden_size) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + nn_GRU_onnx::write(op, captured_params, captured_attrs); + + const auto& B = captured_attrs.at("B.data"); + + bool has_bias = false; + for (auto b : B.get_float32_data()) + { + if (b != 0.f) + { + has_bias = true; + break; + } + } + + op->params["bias"] = has_bias; + + if (has_bias) + { + const int hidden_size = captured_params.at("gru.hidden_size").i; + + // split B and reorder URN to RUN + auto B_data = B.get_float32_data(); + + std::vector B2(3 * hidden_size); + std::vector B3(3 * hidden_size); + { + const float* uptr = (const float*)B_data.data(); + const float* rptr = (const float*)B_data.data() + hidden_size; + const float* nptr = (const float*)B_data.data() + hidden_size * 2; + + float* w_rptr = (float*)B2.data(); + float* w_uptr = (float*)B2.data() + hidden_size; + float* w_nptr = (float*)B2.data() + hidden_size * 2; + + memcpy(w_rptr, rptr, hidden_size * sizeof(float)); + memcpy(w_uptr, uptr, hidden_size * sizeof(float)); + memcpy(w_nptr, nptr, hidden_size * sizeof(float)); + } + { + const float* uptr = (const float*)B_data.data() + hidden_size * 3; + const float* rptr = (const float*)B_data.data() + hidden_size * 4; + const float* nptr = (const float*)B_data.data() + hidden_size * 5; + + float* w_rptr = (float*)B3.data(); + float* w_uptr = (float*)B3.data() + hidden_size; + float* w_nptr = (float*)B3.data() + hidden_size * 2; + + memcpy(w_rptr, rptr, hidden_size * sizeof(float)); + 
memcpy(w_uptr, uptr, hidden_size * sizeof(float)); + memcpy(w_nptr, nptr, hidden_size * sizeof(float)); + } + + std::string direction = "forward"; + if (captured_params.find("gru.direction") != captured_params.end()) + { + direction = captured_params.at("gru.direction").s; + } + + if (direction == "bidirectional") + { + op->attrs["bias_ih_l0"] = Attribute({3 * hidden_size}, B2); + op->attrs["bias_hh_l0"] = Attribute({3 * hidden_size}, B3); + + std::vector B2R(3 * hidden_size); + std::vector B3R(3 * hidden_size); + { + const float* uptr = (const float*)B_data.data() + hidden_size * 6; + const float* rptr = (const float*)B_data.data() + hidden_size * 7; + const float* nptr = (const float*)B_data.data() + hidden_size * 8; + + float* w_rptr = (float*)B2R.data(); + float* w_uptr = (float*)B2R.data() + hidden_size; + float* w_nptr = (float*)B2R.data() + hidden_size * 2; + + memcpy(w_rptr, rptr, hidden_size * sizeof(float)); + memcpy(w_uptr, uptr, hidden_size * sizeof(float)); + memcpy(w_nptr, nptr, hidden_size * sizeof(float)); + } + { + const float* uptr = (const float*)B_data.data() + hidden_size * 9; + const float* rptr = (const float*)B_data.data() + hidden_size * 10; + const float* nptr = (const float*)B_data.data() + hidden_size * 11; + + float* w_rptr = (float*)B3R.data(); + float* w_uptr = (float*)B3R.data() + hidden_size; + float* w_nptr = (float*)B3R.data() + hidden_size * 2; + + memcpy(w_rptr, rptr, hidden_size * sizeof(float)); + memcpy(w_uptr, uptr, hidden_size * sizeof(float)); + memcpy(w_nptr, nptr, hidden_size * sizeof(float)); + } + + op->attrs["bias_ih_l0_reverse"] = Attribute({3 * hidden_size}, B2R); + op->attrs["bias_hh_l0_reverse"] = Attribute({3 * hidden_size}, B3R); + } + else + { + op->attrs["bias_ih_l0"] = Attribute({3 * hidden_size}, B2); + op->attrs["bias_hh_l0"] = Attribute({3 * hidden_size}, B3); + } + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_GRU_onnx_B, 10) + +class nn_GRU_onnx_1 : public nn_GRU_onnx +{ +public: + const char* 
match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 8 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 initial_h +pnnx.Attribute W 0 1 W @data +pnnx.Attribute R 0 1 R @data +GRU gru 4 2 input W R initial_h out outh %*=%* +Squeeze sqz 1 1 out out1 axes=%axes +pnnx.Output output 2 0 out1 outh +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_GRU_onnx_1, 10) + +class nn_GRU_onnx_B1 : public nn_GRU_onnx_B +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 9 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 initial_h +pnnx.Attribute W 0 1 W @data +pnnx.Attribute R 0 1 R @data +pnnx.Attribute B 0 1 B @data +GRU gru 5 2 input W R B initial_h out outh %*=%* +Squeeze sqz 1 1 out out1 axes=%axes +pnnx.Output output 2 0 out1 outh +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_GRU_onnx_B1, 10) + +class nn_GRU_onnx_2 : public nn_GRU_onnx +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 initial_h +pnnx.Attribute W 0 1 W @data +pnnx.Attribute R 0 1 R @data +GRU gru 4 1 input W R initial_h out %*=%* +Squeeze sqz 1 1 out out1 axes=%axes +pnnx.Output output 1 0 out1 +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_GRU_onnx_2, 10) + +class nn_GRU_onnx_B2 : public nn_GRU_onnx_B +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 7 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 initial_h +pnnx.Attribute W 0 1 W @data +pnnx.Attribute R 0 1 R @data +pnnx.Attribute B 0 1 B @data +GRU gru 5 1 input W R B initial_h out %*=%* +Squeeze sqz 1 1 out out1 axes=%axes +pnnx.Output output 1 0 out1 +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_GRU_onnx_B2, 10) + +class nn_GRU_onnx_3 : public nn_GRU_onnx +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input_0 0 1 input +pnnx.Attribute W 0 1 
W @data +pnnx.Attribute R 0 1 R @data +GRU gru 3 1 input W R out %*=%* +Transpose transpose 1 1 out out1 perm=(0,2,1,3) +Reshape reshape 1 1 out1 out2 %*=%* +pnnx.Output output 1 0 out2 +)PNNXIR"; + } + + bool match(const std::map& captured_params, const std::map& captured_attrs) const + { + if (!nn_GRU_onnx::match(captured_params, captured_attrs)) + return false; + + if (captured_params.at("reshape.shape").ai != std::vector{0, 0, -1}) + return false; + + return true; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_GRU_onnx_3, 10) + +class nn_GRU_onnx_B3 : public nn_GRU_onnx_B +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 7 +pnnx.Input input_0 0 1 input +pnnx.Attribute W 0 1 W @data +pnnx.Attribute R 0 1 R @data +pnnx.Attribute B 0 1 B @data +GRU gru 4 1 input W R B out %*=%* +Transpose transpose 1 1 out out1 perm=(0,2,1,3) +Reshape reshape 1 1 out1 out2 %*=%* +pnnx.Output output 1 0 out2 +)PNNXIR"; + } + + bool match(const std::map& captured_params, const std::map& captured_attrs) const + { + if (!nn_GRU_onnx_B::match(captured_params, captured_attrs)) + return false; + + if (captured_params.at("reshape.shape").ai != std::vector{0, 0, -1}) + return false; + + return true; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_GRU_onnx_B3, 10) + +class nn_GRU_onnx_4 : public nn_GRU_onnx_3 +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 9 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 initial_h +pnnx.Attribute W 0 1 W @data +pnnx.Attribute R 0 1 R @data +GRU gru 4 2 input W R initial_h out outh %*=%* +Transpose transpose 1 1 out out1 perm=(0,2,1,3) +Reshape reshape 1 1 out1 out2 %*=%* +pnnx.Output output 2 0 out2 outh +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_GRU_onnx_4, 10) + +class nn_GRU_onnx_B4 : public nn_GRU_onnx_B3 +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +9 10 +pnnx.Input input_0 0 1 input +pnnx.Input 
input_1 0 1 initial_h +pnnx.Attribute W 0 1 W @data +pnnx.Attribute R 0 1 R @data +pnnx.Attribute B 0 1 B @data +GRU gru 5 2 input W R B initial_h out outh %*=%* +Transpose transpose 1 1 out out1 perm=(0,2,1,3) +Reshape reshape 1 1 out1 out2 %*=%* +pnnx.Output output 2 0 out2 outh +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_GRU_onnx_B4, 10) + +class nn_GRU_onnx_5 : public nn_GRU_onnx_3 +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 7 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 initial_h +pnnx.Attribute W 0 1 W @data +pnnx.Attribute R 0 1 R @data +GRU gru 4 1 input W R initial_h out %*=%* +Transpose transpose 1 1 out out1 perm=(0,2,1,3) +Reshape reshape 1 1 out1 out2 %*=%* +pnnx.Output output 1 0 out2 +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_GRU_onnx_5, 10) + +class nn_GRU_onnx_B5 : public nn_GRU_onnx_B3 +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +9 8 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 initial_h +pnnx.Attribute W 0 1 W @data +pnnx.Attribute R 0 1 R @data +pnnx.Attribute B 0 1 B @data +GRU gru 5 1 input W R B initial_h out %*=%* +Transpose transpose 1 1 out out1 perm=(0,2,1,3) +Reshape reshape 1 1 out1 out2 %*=%* +pnnx.Output output 1 0 out2 +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_GRU_onnx_B5, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/nn_LSTM.cpp b/tools/pnnx/src/pass_level2/nn_LSTM.cpp new file mode 100644 index 00000000000..a61961943b1 --- /dev/null +++ b/tools/pnnx/src/pass_level2/nn_LSTM.cpp @@ -0,0 +1,632 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" +#include + +namespace pnnx { + +class nn_LSTM_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Attribute W 0 1 W @data +pnnx.Attribute R 0 1 R @data +LSTM lstm 3 1 input W R out %*=%* +Squeeze sqz 1 1 out out1 axes=%axes +pnnx.Output output 1 0 out1 +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.LSTM"; + } + + const char* name_str() const + { + return "lstm"; + } + + bool match(const std::map& captured_params, const std::map& captured_attrs) const + { + if (captured_params.find("lstm.hidden_size") == captured_params.end()) + return false; + + const int hidden_size = captured_params.at("lstm.hidden_size").i; + + std::string direction = "forward"; + if (captured_params.find("lstm.direction") != captured_params.end()) + { + direction = captured_params.at("lstm.direction").s; + } + + if (direction != "forward" && direction != "bidirectional") + return false; + + const int num_directions = direction == "bidirectional" ? 
2 : 1; + + if (captured_params.find("lstm.activations") != captured_params.end()) + { + const std::vector& acts = captured_params.at("lstm.activations").as; + + if (num_directions == 1) + { + if (acts != std::vector{"Sigmoid", "Tanh", "Tanh"}) + return false; + } + else // if (num_directions == 2) + { + if (acts != std::vector{"Sigmoid", "Tanh", "Tanh", "Sigmoid", "Tanh", "Tanh"}) + return false; + } + } + + if (captured_params.find("axes") != captured_params.end()) + { + if (captured_params.at("axes").type == 2 && captured_params.at("axes").i != 1) + return false; + + if (captured_params.at("axes").type == 5 && captured_params.at("axes").ai != std::vector{1}) + return false; + } + + const auto& W = captured_attrs.at("W.data"); + const auto& R = captured_attrs.at("R.data"); + + if (W.shape.size() != 3 || W.shape[0] != num_directions || W.shape[1] != 4 * hidden_size) + return false; + + if (R.shape.size() != 3 || R.shape[0] != num_directions || R.shape[1] != 4 * hidden_size || R.shape[2] != hidden_size) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + std::string direction = "forward"; + if (captured_params.find("lstm.direction") != captured_params.end()) + { + direction = captured_params.at("lstm.direction").s; + } + + const auto& W = captured_attrs.at("W.data"); + const auto& R = captured_attrs.at("R.data"); + + bool batch_first = false; + if (captured_params.find("lstm.layout") != captured_params.end()) + { + const int layout = captured_params.at("lstm.layout").i; + batch_first = layout == 1; + } + + const int hidden_size = captured_params.at("lstm.hidden_size").i; + + const int input_size = W.shape[2]; + + op->params["input_size"] = input_size; + op->params["hidden_size"] = hidden_size; + op->params["num_layers"] = 1; + op->params["bias"] = false; + op->params["batch_first"] = batch_first; + op->params["bidirectional"] = direction == "bidirectional" ? 
true : false; + op->params["proj_size"] = 0; + + // split W R and reorder IOFG to IFGO + auto W_data = W.get_float32_data(); + auto R_data = R.get_float32_data(); + + std::vector W2(4 * hidden_size * input_size); + { + const int weight_data_size_g = hidden_size * input_size; + + const float* iptr = (const float*)W_data.data(); + const float* optr = (const float*)W_data.data() + weight_data_size_g; + const float* fptr = (const float*)W_data.data() + weight_data_size_g * 2; + const float* gptr = (const float*)W_data.data() + weight_data_size_g * 3; + + float* w_iptr = (float*)W2.data(); + float* w_fptr = (float*)W2.data() + weight_data_size_g; + float* w_gptr = (float*)W2.data() + weight_data_size_g * 2; + float* w_optr = (float*)W2.data() + weight_data_size_g * 3; + + memcpy(w_iptr, iptr, weight_data_size_g * sizeof(float)); + memcpy(w_fptr, fptr, weight_data_size_g * sizeof(float)); + memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float)); + memcpy(w_optr, optr, weight_data_size_g * sizeof(float)); + } + + std::vector R2(4 * hidden_size * hidden_size); + { + const int weight_data_size_g = hidden_size * hidden_size; + + const float* iptr = (const float*)R_data.data(); + const float* optr = (const float*)R_data.data() + weight_data_size_g; + const float* fptr = (const float*)R_data.data() + weight_data_size_g * 2; + const float* gptr = (const float*)R_data.data() + weight_data_size_g * 3; + + float* w_iptr = (float*)R2.data(); + float* w_fptr = (float*)R2.data() + weight_data_size_g; + float* w_gptr = (float*)R2.data() + weight_data_size_g * 2; + float* w_optr = (float*)R2.data() + weight_data_size_g * 3; + + memcpy(w_iptr, iptr, weight_data_size_g * sizeof(float)); + memcpy(w_fptr, fptr, weight_data_size_g * sizeof(float)); + memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float)); + memcpy(w_optr, optr, weight_data_size_g * sizeof(float)); + } + + if (direction == "bidirectional") + { + op->attrs["weight_ih_l0"] = Attribute({4 * hidden_size, input_size}, W2); + 
op->attrs["weight_hh_l0"] = Attribute({4 * hidden_size, hidden_size}, R2); + + std::vector W2R(4 * hidden_size * input_size); + { + const int weight_data_size_g = hidden_size * input_size; + + const float* iptr = (const float*)W_data.data() + weight_data_size_g * 4; + const float* optr = (const float*)W_data.data() + weight_data_size_g * 5; + const float* fptr = (const float*)W_data.data() + weight_data_size_g * 6; + const float* gptr = (const float*)W_data.data() + weight_data_size_g * 7; + + float* w_iptr = (float*)W2R.data(); + float* w_fptr = (float*)W2R.data() + weight_data_size_g; + float* w_gptr = (float*)W2R.data() + weight_data_size_g * 2; + float* w_optr = (float*)W2R.data() + weight_data_size_g * 3; + + memcpy(w_iptr, iptr, weight_data_size_g * sizeof(float)); + memcpy(w_fptr, fptr, weight_data_size_g * sizeof(float)); + memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float)); + memcpy(w_optr, optr, weight_data_size_g * sizeof(float)); + } + + std::vector R2R(4 * hidden_size * hidden_size); + { + const int weight_data_size_g = hidden_size * hidden_size; + + const float* iptr = (const float*)R_data.data() + weight_data_size_g * 4; + const float* optr = (const float*)R_data.data() + weight_data_size_g * 5; + const float* fptr = (const float*)R_data.data() + weight_data_size_g * 6; + const float* gptr = (const float*)R_data.data() + weight_data_size_g * 7; + + float* w_iptr = (float*)R2R.data(); + float* w_fptr = (float*)R2R.data() + weight_data_size_g; + float* w_gptr = (float*)R2R.data() + weight_data_size_g * 2; + float* w_optr = (float*)R2R.data() + weight_data_size_g * 3; + + memcpy(w_iptr, iptr, weight_data_size_g * sizeof(float)); + memcpy(w_fptr, fptr, weight_data_size_g * sizeof(float)); + memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float)); + memcpy(w_optr, optr, weight_data_size_g * sizeof(float)); + } + + op->attrs["weight_ih_l0_reverse"] = Attribute({4 * hidden_size, input_size}, W2R); + op->attrs["weight_hh_l0_reverse"] = Attribute({4 * 
hidden_size, hidden_size}, R2R); + } + else + { + op->attrs["weight_ih_l0"] = Attribute({4 * hidden_size, input_size}, W2); + op->attrs["weight_hh_l0"] = Attribute({4 * hidden_size, hidden_size}, R2); + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_LSTM_onnx, 10) + +class nn_LSTM_onnx_B : public nn_LSTM_onnx +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input_0 0 1 input +pnnx.Attribute W 0 1 W @data +pnnx.Attribute R 0 1 R @data +pnnx.Attribute B 0 1 B @data +LSTM lstm 4 1 input W R B out %*=%* +Squeeze sqz 1 1 out out1 axes=%axes +pnnx.Output output 1 0 out1 +)PNNXIR"; + } + + bool match(const std::map& captured_params, const std::map& captured_attrs) const + { + if (!nn_LSTM_onnx::match(captured_params, captured_attrs)) + return false; + + const int hidden_size = captured_params.at("lstm.hidden_size").i; + + std::string direction = "forward"; + if (captured_params.find("lstm.direction") != captured_params.end()) + { + direction = captured_params.at("lstm.direction").s; + } + + const int num_directions = direction == "bidirectional" ? 
2 : 1; + + const auto& B = captured_attrs.at("B.data"); + + if (B.shape.size() != 2 || B.shape[0] != num_directions || B.shape[1] != 8 * hidden_size) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + nn_LSTM_onnx::write(op, captured_params, captured_attrs); + + const auto& B = captured_attrs.at("B.data"); + + bool has_bias = false; + for (auto b : B.get_float32_data()) + { + if (b != 0.f) + { + has_bias = true; + break; + } + } + + op->params["bias"] = has_bias; + + if (has_bias) + { + const int hidden_size = captured_params.at("lstm.hidden_size").i; + + // split B and reorder IOFG to IFGO + auto B_data = B.get_float32_data(); + + std::vector B2(4 * hidden_size); + std::vector B3(4 * hidden_size); + { + const float* iptr = (const float*)B_data.data(); + const float* optr = (const float*)B_data.data() + hidden_size; + const float* fptr = (const float*)B_data.data() + hidden_size * 2; + const float* gptr = (const float*)B_data.data() + hidden_size * 3; + + float* w_iptr = (float*)B2.data(); + float* w_fptr = (float*)B2.data() + hidden_size; + float* w_gptr = (float*)B2.data() + hidden_size * 2; + float* w_optr = (float*)B2.data() + hidden_size * 3; + + memcpy(w_iptr, iptr, hidden_size * sizeof(float)); + memcpy(w_fptr, fptr, hidden_size * sizeof(float)); + memcpy(w_gptr, gptr, hidden_size * sizeof(float)); + memcpy(w_optr, optr, hidden_size * sizeof(float)); + } + { + const float* iptr = (const float*)B_data.data() + hidden_size * 4; + const float* optr = (const float*)B_data.data() + hidden_size * 5; + const float* fptr = (const float*)B_data.data() + hidden_size * 6; + const float* gptr = (const float*)B_data.data() + hidden_size * 7; + + float* w_iptr = (float*)B3.data(); + float* w_fptr = (float*)B3.data() + hidden_size; + float* w_gptr = (float*)B3.data() + hidden_size * 2; + float* w_optr = (float*)B3.data() + hidden_size * 3; + + memcpy(w_iptr, iptr, hidden_size * 
sizeof(float)); + memcpy(w_fptr, fptr, hidden_size * sizeof(float)); + memcpy(w_gptr, gptr, hidden_size * sizeof(float)); + memcpy(w_optr, optr, hidden_size * sizeof(float)); + } + + std::string direction = "forward"; + if (captured_params.find("lstm.direction") != captured_params.end()) + { + direction = captured_params.at("lstm.direction").s; + } + + if (direction == "bidirectional") + { + op->attrs["bias_ih_l0"] = Attribute({4 * hidden_size}, B2); + op->attrs["bias_hh_l0"] = Attribute({4 * hidden_size}, B3); + + std::vector B2R(4 * hidden_size); + std::vector B3R(4 * hidden_size); + { + const float* iptr = (const float*)B_data.data() + hidden_size * 8; + const float* optr = (const float*)B_data.data() + hidden_size * 9; + const float* fptr = (const float*)B_data.data() + hidden_size * 10; + const float* gptr = (const float*)B_data.data() + hidden_size * 11; + + float* w_iptr = (float*)B2R.data(); + float* w_fptr = (float*)B2R.data() + hidden_size; + float* w_gptr = (float*)B2R.data() + hidden_size * 2; + float* w_optr = (float*)B2R.data() + hidden_size * 3; + + memcpy(w_iptr, iptr, hidden_size * sizeof(float)); + memcpy(w_fptr, fptr, hidden_size * sizeof(float)); + memcpy(w_gptr, gptr, hidden_size * sizeof(float)); + memcpy(w_optr, optr, hidden_size * sizeof(float)); + } + { + const float* iptr = (const float*)B_data.data() + hidden_size * 12; + const float* optr = (const float*)B_data.data() + hidden_size * 13; + const float* fptr = (const float*)B_data.data() + hidden_size * 14; + const float* gptr = (const float*)B_data.data() + hidden_size * 15; + + float* w_iptr = (float*)B3R.data(); + float* w_fptr = (float*)B3R.data() + hidden_size; + float* w_gptr = (float*)B3R.data() + hidden_size * 2; + float* w_optr = (float*)B3R.data() + hidden_size * 3; + + memcpy(w_iptr, iptr, hidden_size * sizeof(float)); + memcpy(w_fptr, fptr, hidden_size * sizeof(float)); + memcpy(w_gptr, gptr, hidden_size * sizeof(float)); + memcpy(w_optr, optr, hidden_size * sizeof(float)); + 
} + + op->attrs["bias_ih_l0_reverse"] = Attribute({4 * hidden_size}, B2R); + op->attrs["bias_hh_l0_reverse"] = Attribute({4 * hidden_size}, B3R); + } + else + { + op->attrs["bias_ih_l0"] = Attribute({4 * hidden_size}, B2); + op->attrs["bias_hh_l0"] = Attribute({4 * hidden_size}, B3); + } + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_LSTM_onnx_B, 10) + +class nn_LSTM_onnx_1 : public nn_LSTM_onnx +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 9 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 initial_h +pnnx.Input input_2 0 1 initial_c +pnnx.Attribute W 0 1 W @data +pnnx.Attribute R 0 1 R @data +LSTM lstm 5 3 input W R initial_h initial_c out outh outc %*=%* +Squeeze sqz 1 1 out out1 axes=%axes +pnnx.Output output 3 0 out1 outh outc +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_LSTM_onnx_1, 10) + +class nn_LSTM_onnx_B1 : public nn_LSTM_onnx_B +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +9 10 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 initial_h +pnnx.Input input_2 0 1 initial_c +pnnx.Attribute W 0 1 W @data +pnnx.Attribute R 0 1 R @data +pnnx.Attribute B 0 1 B @data +LSTM lstm 6 3 input W R B initial_h initial_c out outh outc %*=%* +Squeeze sqz 1 1 out out1 axes=%axes +pnnx.Output output 3 0 out1 outh outc +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_LSTM_onnx_B1, 10) + +class nn_LSTM_onnx_2 : public nn_LSTM_onnx +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 7 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 initial_h +pnnx.Input input_2 0 1 initial_c +pnnx.Attribute W 0 1 W @data +pnnx.Attribute R 0 1 R @data +LSTM lstm 5 1 input W R initial_h initial_c out %*=%* +Squeeze sqz 1 1 out out1 axes=%axes +pnnx.Output output 1 0 out1 +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_LSTM_onnx_2, 10) + +class nn_LSTM_onnx_B2 : public nn_LSTM_onnx_B +{ +public: + 
const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+9 8
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 initial_h
+pnnx.Input input_2 0 1 initial_c
+pnnx.Attribute W 0 1 W @data
+pnnx.Attribute R 0 1 R @data
+pnnx.Attribute B 0 1 B @data
+LSTM lstm 6 1 input W R B initial_h initial_c out %*=%*
+Squeeze sqz 1 1 out out1 axes=%axes
+pnnx.Output output 1 0 out1
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_LSTM_onnx_B2, 10)
+
+// LSTM(input, W, R) -> Transpose(perm=0,2,1,3) -> Reshape, single output.
+class nn_LSTM_onnx_3 : public nn_LSTM_onnx
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+7 6
+pnnx.Input input_0 0 1 input
+pnnx.Attribute W 0 1 W @data
+pnnx.Attribute R 0 1 R @data
+LSTM lstm 3 1 input W R out %*=%*
+Transpose transpose 1 1 out out1 perm=(0,2,1,3)
+Reshape reshape 1 1 out1 out2 %*=%*
+pnnx.Output output 1 0 out2
+)PNNXIR";
+    }
+
+    // Accepts the match only when the base LSTM checks pass and the Reshape target is {0, 0, -1}.
+    bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        if (!nn_LSTM_onnx::match(captured_params, captured_attrs))
+            return false;
+
+        // a Reshape without a captured shape parameter cannot be verified; reject instead of letting at() throw
+        const auto shape_it = captured_params.find("reshape.shape");
+        if (shape_it == captured_params.end())
+            return false;
+
+        if (shape_it->second.ai != std::vector<int>{0, 0, -1})
+            return false;
+
+        return true;
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_LSTM_onnx_3, 10)
+
+// LSTM(input, W, R, B) -> Transpose(perm=0,2,1,3) -> Reshape, single output.
+class nn_LSTM_onnx_B3 : public nn_LSTM_onnx_B
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+8 7
+pnnx.Input input_0 0 1 input
+pnnx.Attribute W 0 1 W @data
+pnnx.Attribute R 0 1 R @data
+pnnx.Attribute B 0 1 B @data
+LSTM lstm 4 1 input W R B out %*=%*
+Transpose transpose 1 1 out out1 perm=(0,2,1,3)
+Reshape reshape 1 1 out1 out2 %*=%*
+pnnx.Output output 1 0 out2
+)PNNXIR";
+    }
+
+    // Accepts the match only when the biased LSTM checks pass and the Reshape target is {0, 0, -1}.
+    bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        if (!nn_LSTM_onnx_B::match(captured_params, captured_attrs))
+            return false;
+
+        // a Reshape without a captured shape parameter cannot be verified; reject instead of letting at() throw
+        const auto shape_it = captured_params.find("reshape.shape");
+        if (shape_it == captured_params.end())
+            return false;
+
+        if (shape_it->second.ai != std::vector<int>{0, 0, -1})
+            return false;
+
+        return true;
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_LSTM_onnx_B3, 10)
+
+// Same as nn_LSTM_onnx_3 with initial_h / initial_c inputs and the outh / outc outputs kept.
+class nn_LSTM_onnx_4 : public nn_LSTM_onnx_3
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+9 10
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 initial_h
+pnnx.Input input_2 0 1 initial_c
+pnnx.Attribute W 0 1 W @data
+pnnx.Attribute R 0 1 R @data
+LSTM lstm 5 3 input W R initial_h initial_c out outh outc %*=%*
+Transpose transpose 1 1 out out1 perm=(0,2,1,3)
+Reshape reshape 1 1 out1 out2 %*=%*
+pnnx.Output output 3 0 out2 outh outc
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_LSTM_onnx_4, 10)
+
+// Same as nn_LSTM_onnx_B3 with initial_h / initial_c inputs and the outh / outc outputs kept.
+class nn_LSTM_onnx_B4 : public nn_LSTM_onnx_B3
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+10 11
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 initial_h
+pnnx.Input input_2 0 1 initial_c
+pnnx.Attribute W 0 1 W @data
+pnnx.Attribute R 0 1 R @data
+pnnx.Attribute B 0 1 B @data
+LSTM lstm 6 3 input W R B initial_h initial_c out outh outc %*=%*
+Transpose transpose 1 1 out out1 perm=(0,2,1,3)
+Reshape reshape 1 1 out1 out2 %*=%*
+pnnx.Output output 3 0 out2 outh outc
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_LSTM_onnx_B4, 10)
+
+// Same as nn_LSTM_onnx_3 with initial_h / initial_c inputs; only the sequence output is consumed.
+class nn_LSTM_onnx_5 : public nn_LSTM_onnx_3
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+9 8
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 initial_h
+pnnx.Input input_2 0 1 initial_c
+pnnx.Attribute W 0 1 W @data
+pnnx.Attribute R 0 1 R @data
+LSTM lstm 5 1 input W R initial_h initial_c out %*=%*
+Transpose transpose 1 1 out out1 perm=(0,2,1,3)
+Reshape reshape 1 1 out1 out2 %*=%*
+pnnx.Output output 1 0 out2
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_LSTM_onnx_5, 10)
+
+// Same as nn_LSTM_onnx_B3 with initial_h / initial_c inputs; only the sequence output is consumed.
+class nn_LSTM_onnx_B5 : public nn_LSTM_onnx_B3
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+10 9
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 initial_h
+pnnx.Input input_2 0 1 initial_c
+pnnx.Attribute W 0 1 W @data
+pnnx.Attribute R 0 1 R @data
+pnnx.Attribute B 0 1 B @data
+LSTM lstm 6 1 input W R B initial_h initial_c out %*=%*
+Transpose transpose 1 1 out out1 perm=(0,2,1,3)
+Reshape reshape 1 1 out1 out2 %*=%*
+pnnx.Output output 1 0 out2
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_LSTM_onnx_B5, 10)
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/nn_RNN.cpp b/tools/pnnx/src/pass_level2/nn_RNN.cpp
new file mode 100644
index 00000000000..a3784f44c99
--- /dev/null
+++ b/tools/pnnx/src/pass_level2/nn_RNN.cpp
@@ -0,0 +1,479 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "pass_level2.h"
+
+namespace pnnx {
+
+// Rewrites RNN(input, W, R) -> Squeeze into a single bias-free nn.RNN operator.
+class nn_RNN_onnx : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+6 5
+pnnx.Input input_0 0 1 input
+pnnx.Attribute W 0 1 W @data
+pnnx.Attribute R 0 1 R @data
+RNN rnn 3 1 input W R out %*=%*
+Squeeze sqz 1 1 out out1 axes=%axes
+pnnx.Output output 1 0 out1
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "nn.RNN";
+    }
+
+    const char* name_str() const
+    {
+        return "rnn";
+    }
+
+    // Accepts forward / bidirectional RNN nodes using Tanh or Relu whose W and R
+    // attribute shapes agree with hidden_size and the number of directions.
+    bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        if (captured_params.find("rnn.hidden_size") == captured_params.end())
+            return false;
+
+        const int hidden_size = captured_params.at("rnn.hidden_size").i;
+
+        std::string direction = "forward";
+        if (captured_params.find("rnn.direction") != captured_params.end())
+        {
+            direction = captured_params.at("rnn.direction").s;
+        }
+
+        if (direction != "forward" && direction != "bidirectional")
+            return false;
+
+        const int num_directions = direction == "bidirectional" ? 2 : 1;
+
+        if (captured_params.find("rnn.activations") != captured_params.end())
+        {
+            const std::vector<std::string>& acts = captured_params.at("rnn.activations").as;
+
+            if (num_directions == 1)
+            {
+                if (acts != std::vector<std::string>{"Tanh"} && acts != std::vector<std::string>{"Relu"})
+                    return false;
+            }
+            else // if (num_directions == 2)
+            {
+                if (acts != std::vector<std::string>{"Tanh", "Tanh"} && acts != std::vector<std::string>{"Relu", "Relu"})
+                    return false;
+            }
+        }
+
+        // the trailing Squeeze, when captured, must remove axis 1 (given as a scalar or a one-element list)
+        if (captured_params.find("axes") != captured_params.end())
+        {
+            if (captured_params.at("axes").type == 2 && captured_params.at("axes").i != 1)
+                return false;
+
+            if (captured_params.at("axes").type == 5 && captured_params.at("axes").ai != std::vector<int>{1})
+                return false;
+        }
+
+        const auto& W = captured_attrs.at("W.data");
+        const auto& R = captured_attrs.at("R.data");
+
+        if (W.shape.size() != 3 || W.shape[0] != num_directions || W.shape[1] != hidden_size)
+            return false;
+
+        if (R.shape.size() != 3 || R.shape[0] != num_directions || R.shape[1] != hidden_size || R.shape[2] != hidden_size)
+            return false;
+
+        return true;
+    }
+
+    // Fills the nn.RNN params and splits W / R into per-direction weight_ih / weight_hh attributes.
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        std::string direction = "forward";
+        if (captured_params.find("rnn.direction") != captured_params.end())
+        {
+            direction = captured_params.at("rnn.direction").s;
+        }
+
+        const bool bidirectional = direction == "bidirectional";
+
+        std::string act = "Tanh";
+        if (captured_params.find("rnn.activations") != captured_params.end())
+        {
+            act = captured_params.at("rnn.activations").as[0];
+        }
+
+        const auto& W = captured_attrs.at("W.data");
+        const auto& R = captured_attrs.at("R.data");
+
+        bool batch_first = false;
+        if (captured_params.find("rnn.layout") != captured_params.end())
+        {
+            const int layout = captured_params.at("rnn.layout").i;
+            batch_first = layout == 1;
+        }
+
+        const int hidden_size = captured_params.at("rnn.hidden_size").i;
+
+        const int input_size = W.shape[2];
+
+        op->params["input_size"] = input_size;
+        op->params["hidden_size"] = hidden_size;
+        op->params["num_layers"] = 1;
+        op->params["nonlinearity"] = act == "Relu" ? "relu" : "tanh";
+        op->params["bias"] = false;
+        op->params["batch_first"] = batch_first;
+        op->params["bidirectional"] = bidirectional;
+
+        // split W R
+        auto W_data = W.get_float32_data();
+        auto R_data = R.get_float32_data();
+
+        if (bidirectional)
+        {
+            // element count of one direction's slice
+            const int W_step = hidden_size * input_size;
+            const int R_step = hidden_size * hidden_size;
+
+            // iterator arithmetic instead of &v[n]: the reverse slice ends at size(), where operator[] must not be used
+            op->attrs["weight_ih_l0"] = Attribute({hidden_size, input_size}, std::vector<float>(W_data.begin(), W_data.begin() + W_step));
+            op->attrs["weight_hh_l0"] = Attribute({hidden_size, hidden_size}, std::vector<float>(R_data.begin(), R_data.begin() + R_step));
+
+            op->attrs["weight_ih_l0_reverse"] = Attribute({hidden_size, input_size}, std::vector<float>(W_data.begin() + W_step, W_data.begin() + W_step * 2));
+            op->attrs["weight_hh_l0_reverse"] = Attribute({hidden_size, hidden_size}, std::vector<float>(R_data.begin() + R_step, R_data.begin() + R_step * 2));
+        }
+        else
+        {
+            op->attrs["weight_ih_l0"] = Attribute({hidden_size, input_size}, W_data);
+            op->attrs["weight_hh_l0"] = Attribute({hidden_size, hidden_size}, R_data);
+        }
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_RNN_onnx, 10)
+
+// Variant of nn_RNN_onnx whose RNN node also carries the bias attribute B.
+class nn_RNN_onnx_B : public nn_RNN_onnx
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+7 6
+pnnx.Input input_0 0 1 input
+pnnx.Attribute W 0 1 W @data
+pnnx.Attribute R 0 1 R @data
+pnnx.Attribute B 0 1 B @data
+RNN rnn 4 1 input W R B out %*=%*
+Squeeze sqz 1 1 out out1 axes=%axes
+pnnx.Output output 1 0 out1
+)PNNXIR";
+    }
+
+    // Base checks plus: B must have shape {num_directions, 2 * hidden_size}.
+    bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        if (!nn_RNN_onnx::match(captured_params, captured_attrs))
+            return false;
+
+        const int hidden_size = captured_params.at("rnn.hidden_size").i;
+
+        std::string direction = "forward";
+        if (captured_params.find("rnn.direction") != captured_params.end())
+        {
+            direction = captured_params.at("rnn.direction").s;
+        }
+
+        const int num_directions = direction == "bidirectional" ? 2 : 1;
+
+        const auto& B = captured_attrs.at("B.data");
+
+        if (B.shape.size() != 2 || B.shape[0] != num_directions || B.shape[1] != 2 * hidden_size)
+            return false;
+
+        return true;
+    }
+
+    // Writes the weights through the base class, then emits bias_ih / bias_hh
+    // attributes unless every element of B is zero.
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        nn_RNN_onnx::write(op, captured_params, captured_attrs);
+
+        const auto& B = captured_attrs.at("B.data");
+
+        // decode once and reuse for both the all-zero test and the split
+        const auto B_data = B.get_float32_data();
+
+        bool has_bias = false;
+        for (auto b : B_data)
+        {
+            if (b != 0.f)
+            {
+                has_bias = true;
+                break;
+            }
+        }
+
+        op->params["bias"] = has_bias;
+
+        if (has_bias)
+        {
+            // split B
+            const int hidden_size = captured_params.at("rnn.hidden_size").i;
+
+            std::string direction = "forward";
+            if (captured_params.find("rnn.direction") != captured_params.end())
+            {
+                direction = captured_params.at("rnn.direction").s;
+            }
+
+            // iterator arithmetic instead of &v[n]: the last slice ends at size(), where operator[] must not be used
+            const auto B_begin = B_data.begin();
+
+            op->attrs["bias_ih_l0"] = Attribute({hidden_size}, std::vector<float>(B_begin, B_begin + hidden_size));
+            op->attrs["bias_hh_l0"] = Attribute({hidden_size}, std::vector<float>(B_begin + hidden_size, B_begin + hidden_size * 2));
+
+            if (direction == "bidirectional")
+            {
+                op->attrs["bias_ih_l0_reverse"] = Attribute({hidden_size}, std::vector<float>(B_begin + hidden_size * 2, B_begin + hidden_size * 3));
+                op->attrs["bias_hh_l0_reverse"] = Attribute({hidden_size}, std::vector<float>(B_begin + hidden_size * 3, B_begin + hidden_size * 4));
+            }
+        }
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_RNN_onnx_B, 10)
+
+// nn_RNN_onnx with an initial_h input and the outh output kept.
+class nn_RNN_onnx_1 : public nn_RNN_onnx
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+7 7
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 initial_h
+pnnx.Attribute W 0 1 W @data
+pnnx.Attribute R 0 1 R @data
+RNN rnn 4 2 input W R initial_h out outh %*=%*
+Squeeze sqz 1 1 out out1 axes=%axes
+pnnx.Output output 2 0 out1 outh
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_RNN_onnx_1, 10)
+
+// nn_RNN_onnx_B with an initial_h input and the outh output kept.
+class nn_RNN_onnx_B1 : public nn_RNN_onnx_B
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+8 8
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 initial_h
+pnnx.Attribute W 0 1 W @data
+pnnx.Attribute R 0 1 R @data
+pnnx.Attribute B 0 1 B @data
+RNN rnn 5 2 input W R B initial_h out outh %*=%*
+Squeeze sqz 1 1 out out1 axes=%axes
+pnnx.Output output 2 0 out1 outh
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_RNN_onnx_B1, 10)
+
+// nn_RNN_onnx with an initial_h input; only the sequence output is consumed.
+class nn_RNN_onnx_2 : public nn_RNN_onnx
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+7 6
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 initial_h
+pnnx.Attribute W 0 1 W @data
+pnnx.Attribute R 0 1 R @data
+RNN rnn 4 1 input W R initial_h out %*=%*
+Squeeze sqz 1 1 out out1 axes=%axes
+pnnx.Output output 1 0 out1
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_RNN_onnx_2, 10)
+
+// nn_RNN_onnx_B with an initial_h input; only the sequence output is consumed.
+class nn_RNN_onnx_B2 : public nn_RNN_onnx_B
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+8 7
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 initial_h
+pnnx.Attribute W 0 1 W @data
+pnnx.Attribute R 0 1 R @data
+pnnx.Attribute B 0 1 B @data
+RNN rnn 5 1 input W R B initial_h out %*=%*
+Squeeze sqz 1 1 out out1 axes=%axes
+pnnx.Output output 1 0 out1
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_RNN_onnx_B2, 10)
+
+// RNN(input, W, R) -> Transpose(perm=0,2,1,3) -> Reshape, single output.
+class nn_RNN_onnx_3 : public nn_RNN_onnx
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+7 6
+pnnx.Input input_0 0 1 input
+pnnx.Attribute W 0 1 W @data
+pnnx.Attribute R 0 1 R @data
+RNN rnn 3 1 input W R out %*=%*
+Transpose transpose 1 1 out out1 perm=(0,2,1,3)
+Reshape reshape 1 1 out1 out2 %*=%*
+pnnx.Output output 1 0 out2
+)PNNXIR";
+    }
+
+    // Accepts the match only when the base RNN checks pass and the Reshape target is {0, 0, -1}.
+    bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        if (!nn_RNN_onnx::match(captured_params, captured_attrs))
+            return false;
+
+        // a Reshape without a captured shape parameter cannot be verified; reject instead of letting at() throw
+        const auto shape_it = captured_params.find("reshape.shape");
+        if (shape_it == captured_params.end())
+            return false;
+
+        if (shape_it->second.ai != std::vector<int>{0, 0, -1})
+            return false;
+
+        return true;
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_RNN_onnx_3, 10)
+
+// RNN(input, W, R, B) -> Transpose(perm=0,2,1,3) -> Reshape, single output.
+class nn_RNN_onnx_B3 : public nn_RNN_onnx_B
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+8 7
+pnnx.Input input_0 0 1 input
+pnnx.Attribute W 0 1 W @data
+pnnx.Attribute R 0 1 R @data
+pnnx.Attribute B 0 1 B @data
+RNN rnn 4 1 input W R B out %*=%*
+Transpose transpose 1 1 out out1 perm=(0,2,1,3)
+Reshape reshape 1 1 out1 out2 %*=%*
+pnnx.Output output 1 0 out2
+)PNNXIR";
+    }
+
+    // Accepts the match only when the biased RNN checks pass and the Reshape target is {0, 0, -1}.
+    bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        if (!nn_RNN_onnx_B::match(captured_params, captured_attrs))
+            return false;
+
+        // a Reshape without a captured shape parameter cannot be verified; reject instead of letting at() throw
+        const auto shape_it = captured_params.find("reshape.shape");
+        if (shape_it == captured_params.end())
+            return false;
+
+        if (shape_it->second.ai != std::vector<int>{0, 0, -1})
+            return false;
+
+        return true;
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_RNN_onnx_B3, 10)
+
+// nn_RNN_onnx_3 with an initial_h input and the outh output kept.
+class nn_RNN_onnx_4 : public nn_RNN_onnx_3
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+8 8
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 initial_h
+pnnx.Attribute W 0 1 W @data
+pnnx.Attribute R 0 1 R @data
+RNN rnn 4 2 input W R initial_h out outh %*=%*
+Transpose transpose 1 1 out out1 perm=(0,2,1,3)
+Reshape reshape 1 1 out1 out2 %*=%*
+pnnx.Output output 2 0 out2 outh
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_RNN_onnx_4, 10)
+
+// nn_RNN_onnx_B3 with an initial_h input and the outh output kept.
+class nn_RNN_onnx_B4 : public nn_RNN_onnx_B3
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+9 9
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 initial_h
+pnnx.Attribute W 0 1 W @data
+pnnx.Attribute R 0 1 R @data
+pnnx.Attribute B 0 1 B @data
+RNN rnn 5 2 input W R B initial_h out outh %*=%*
+Transpose transpose 1 1 out out1 perm=(0,2,1,3)
+Reshape reshape 1 1 out1 out2 %*=%*
+pnnx.Output output 2 0 out2 outh
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_RNN_onnx_B4, 10)
+
+// nn_RNN_onnx_3 with an initial_h input; only the sequence output is consumed.
+class nn_RNN_onnx_5 : public nn_RNN_onnx_3
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+8 7
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 initial_h
+pnnx.Attribute W 0 1 W @data
+pnnx.Attribute R 0 1 R @data
+RNN rnn 4 1 input W R initial_h out %*=%*
+Transpose transpose 1 1 out out1 perm=(0,2,1,3)
+Reshape reshape 1 1 out1 out2 %*=%*
+pnnx.Output output 1 0 out2
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_RNN_onnx_5, 10)
+
+// nn_RNN_onnx_B3 with an initial_h input; only the sequence output is consumed.
+class nn_RNN_onnx_B5 : public nn_RNN_onnx_B3
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+9 8
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 initial_h
+pnnx.Attribute W 0 1 W @data
+pnnx.Attribute R 0 1 R @data
+pnnx.Attribute B 0 1 B @data
+RNN rnn 5 1 input W R B initial_h out %*=%*
+Transpose transpose 1 1 out out1 perm=(0,2,1,3)
+Reshape reshape 1 1 out1 out2 %*=%*
+pnnx.Output output 1 0 out2
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_RNN_onnx_B5, 10)
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_onnx/dead_code_elimination.cpp b/tools/pnnx/src/pass_onnx/dead_code_elimination.cpp index 733aaa96b09..d033bcd4c84 100644 --- a/tools/pnnx/src/pass_onnx/dead_code_elimination.cpp +++ b/tools/pnnx/src/pass_onnx/dead_code_elimination.cpp @@ -145,7 +145,7 @@ void dead_code_elimination(onnx::ModelProto& model) const int graph_value_info_size = graph->value_info_size(); for (int k = j; k < graph_value_info_size - 1; k++) { - graph->mutable_node()->SwapElements(k, k + 1); + graph->mutable_value_info()->SwapElements(k, k + 1); } // ..... .......
j @@ -155,6 +155,36 @@ void dead_code_elimination(onnx::ModelProto& model) } } } + + for (int i = 0; i < graph->node_size(); i++) + { + onnx::NodeProto* node = graph->mutable_node(i); + + for (int j = 0; j < node->output_size(); j++) + { + bool has_dead_output = false; + for (auto x : dead_outputs) + { + if (x == node->output(j)) + { + has_dead_output = true; + break; + } + } + + if (has_dead_output) + { + // drop optional unused outputs + const int node_output_size = node->output_size(); + for (int k = j; k < node_output_size; k++) + { + node->mutable_output()->RemoveLast(); + } + + break; + } + } + } } // collect all dead functions diff --git a/tools/pnnx/src/pass_onnx/eliminate_noop.cpp b/tools/pnnx/src/pass_onnx/eliminate_noop.cpp index f271c9b31d5..5c0ce816700 100644 --- a/tools/pnnx/src/pass_onnx/eliminate_noop.cpp +++ b/tools/pnnx/src/pass_onnx/eliminate_noop.cpp @@ -113,16 +113,16 @@ void eliminate_noop_with_shape(onnx::ModelProto& model) const onnx::NodeProto& node = graph->node(i); const std::string& op_type = node.op_type(); - onnx::ValueInfoProto* input_value = find_value_info_by_name(graph, node.input(0)); - onnx::ValueInfoProto* output_value = find_value_info_by_name(graph, node.output(0)); - - if (!input_value || !output_value) - continue; - bool noop = false; if (op_type == "Cast") { + onnx::ValueInfoProto* input_value = find_value_info_by_name(graph, node.input(0)); + onnx::ValueInfoProto* output_value = find_value_info_by_name(graph, node.output(0)); + + if (!input_value || !output_value) + continue; + if (input_value->type().has_tensor_type() && output_value->type().has_tensor_type()) { if (input_value->type().tensor_type().elem_type() == output_value->type().tensor_type().elem_type()) @@ -132,6 +132,12 @@ void eliminate_noop_with_shape(onnx::ModelProto& model) if (op_type == "Reshape") { + onnx::ValueInfoProto* input_value = find_value_info_by_name(graph, node.input(0)); + onnx::ValueInfoProto* output_value = find_value_info_by_name(graph, 
node.output(0)); + + if (!input_value || !output_value) + continue; + if (input_value->type().has_tensor_type() && output_value->type().has_tensor_type()) { const onnx::TensorShapeProto& input_tsp = input_value->type().tensor_type().shape(); diff --git a/tools/pnnx/tests/onnx/CMakeLists.txt b/tools/pnnx/tests/onnx/CMakeLists.txt index 229de6020fb..6b9c6db7553 100644 --- a/tools/pnnx/tests/onnx/CMakeLists.txt +++ b/tools/pnnx/tests/onnx/CMakeLists.txt @@ -48,12 +48,14 @@ pnnx_onnx_add_test(nn_ConvTranspose1d) pnnx_onnx_add_test(nn_ConvTranspose2d) pnnx_onnx_add_test(nn_ConvTranspose3d) pnnx_onnx_add_test(nn_GroupNorm) +pnnx_onnx_add_test(nn_GRU) pnnx_onnx_add_test(nn_InstanceNorm1d) pnnx_onnx_add_test(nn_InstanceNorm2d) pnnx_onnx_add_test(nn_InstanceNorm3d) pnnx_onnx_add_test(nn_LayerNorm) pnnx_onnx_add_test(nn_Linear) pnnx_onnx_add_test(nn_LocalResponseNorm) +pnnx_onnx_add_test(nn_LSTM) pnnx_onnx_add_test(nn_MaxPool1d) pnnx_onnx_add_test(nn_MaxPool2d) pnnx_onnx_add_test(nn_MaxPool3d) @@ -63,6 +65,7 @@ pnnx_onnx_add_test(nn_ReLU) pnnx_onnx_add_test(nn_ReplicationPad1d) pnnx_onnx_add_test(nn_ReplicationPad2d) pnnx_onnx_add_test(nn_ReplicationPad3d) +pnnx_onnx_add_test(nn_RNN) pnnx_onnx_add_test(nn_Sigmoid) pnnx_onnx_add_test(nn_Softmax) pnnx_onnx_add_test(nn_Upsample) diff --git a/tools/pnnx/tests/onnx/test_nn_GRU.py b/tools/pnnx/tests/onnx/test_nn_GRU.py new file mode 100644 index 00000000000..98ea95a75b3 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_GRU.py @@ -0,0 +1,75 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. 
You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software distributed
+# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+# CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    """Chains nn.GRU stacks: seq-first on x, batch_first on y, uni- and bidirectional."""
+    def __init__(self):
+        super(Model, self).__init__()
+
+        self.gru_0_0 = nn.GRU(input_size=32, hidden_size=16)
+        self.gru_0_1 = nn.GRU(input_size=16, hidden_size=16, num_layers=3, bias=False)
+        self.gru_0_2 = nn.GRU(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True)
+        self.gru_0_3 = nn.GRU(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True)
+
+        self.gru_1_0 = nn.GRU(input_size=25, hidden_size=16, batch_first=True)
+        self.gru_1_1 = nn.GRU(input_size=16, hidden_size=16, num_layers=3, bias=False, batch_first=True)
+        self.gru_1_2 = nn.GRU(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True)
+        self.gru_1_3 = nn.GRU(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True)
+
+    def forward(self, x, y):
+        x0, h0 = self.gru_0_0(x)
+        x1, _ = self.gru_0_1(x0)
+        x2, h2 = self.gru_0_2(x1)
+        x3, h3 = self.gru_0_3(x1, h2)
+
+        y0, h4 = self.gru_1_0(y)
+        y1, _ = self.gru_1_1(y0)
+        y2, h6 = self.gru_1_2(y1)
+        y3, h7 = self.gru_1_3(y1, h6)
+        return x2, x3, h0, h2, h3, y2, y3, h4, h6, h7
+
+def test():
+    """Export Model to ONNX, convert it with pnnx and compare against the PyTorch outputs."""
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)
+    x = torch.rand(10, 1, 32)
+    y = torch.rand(1, 12, 25)
+
+    a = net(x, y)
+
+    # export onnx
+    torch.onnx.export(net, (x, y), "test_nn_GRU.onnx")
+
+    # onnx to pnnx
+    import os
+    # a failed conversion must fail the test instead of importing a stale generated module
+    if os.system("../../src/pnnx test_nn_GRU.onnx inputshape=[10,1,32],[1,12,25]") != 0:
+        return False
+
+    # pnnx inference
+    import test_nn_GRU_pnnx
+    b = test_nn_GRU_pnnx.test_inference()
+
+    # zip() stops at the shorter sequence, so a missing output would otherwise go unnoticed
+    if len(a) != len(b):
+        return False
+
+    for a0, b0 in zip(a, b):
+        if not torch.equal(a0, b0):
+            return False
+    return True
+
+if __name__ == "__main__":
+    if test():
+        exit(0)
+    else:
+        exit(1)
diff --git a/tools/pnnx/tests/onnx/test_nn_LSTM.py b/tools/pnnx/tests/onnx/test_nn_LSTM.py
new file mode 100644
index 00000000000..7c360ffa14f
--- /dev/null
+++ b/tools/pnnx/tests/onnx/test_nn_LSTM.py
@@ -0,0 +1,75 @@
+# Tencent is pleased to support the open source community by making ncnn available.
+#
+# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software distributed
+# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+# CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    """Chains nn.LSTM stacks: seq-first on x, batch_first on y, uni- and bidirectional."""
+    def __init__(self):
+        super(Model, self).__init__()
+
+        self.lstm_0_0 = nn.LSTM(input_size=32, hidden_size=16)
+        self.lstm_0_1 = nn.LSTM(input_size=16, hidden_size=16, num_layers=3, bias=False)
+        self.lstm_0_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True)
+        self.lstm_0_3 = nn.LSTM(input_size=32, hidden_size=16, num_layers=4, bias=True, bidirectional=True)
+
+        self.lstm_1_0 = nn.LSTM(input_size=25, hidden_size=16, batch_first=True)
+        self.lstm_1_1 = nn.LSTM(input_size=16, hidden_size=16, num_layers=3, bias=False, batch_first=True)
+        self.lstm_1_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True)
+        self.lstm_1_3 = nn.LSTM(input_size=32, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True)
+
+    def forward(self, x, y):
+        x0, (h0, c0) = self.lstm_0_0(x)
+        x1, _ = self.lstm_0_1(x0)
+        x2, (h2, c2) = self.lstm_0_2(x1)
+        x3, (h3, c3) = self.lstm_0_3(x2, (h2, c2))
+
+        y0, (h4, c4) = self.lstm_1_0(y)
+        y1, _ = self.lstm_1_1(y0)
+        y2, (h6, c6) = self.lstm_1_2(y1)
+        y3, (h7, c7) = self.lstm_1_3(y2, (h6, c6))
+        return x2, x3, h0, h2, h3, c0, c2, c3, y2, y3, h4, h6, h7, c4, c6, c7
+
+def test():
+    """Export Model to ONNX, convert it with pnnx and compare against the PyTorch outputs."""
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)
+    x = torch.rand(10, 1, 32)
+    y = torch.rand(1, 12, 25)
+
+    a = net(x, y)
+
+    # export onnx
+    torch.onnx.export(net, (x, y), "test_nn_LSTM.onnx")
+
+    # onnx to pnnx
+    import os
+    # a failed conversion must fail the test instead of importing a stale generated module
+    if os.system("../../src/pnnx test_nn_LSTM.onnx inputshape=[10,1,32],[1,12,25]") != 0:
+        return False
+
+    # pnnx inference
+    import test_nn_LSTM_pnnx
+    b = test_nn_LSTM_pnnx.test_inference()
+
+    # zip() stops at the shorter sequence, so a missing output would otherwise go unnoticed
+    if len(a) != len(b):
+        return False
+
+    for a0, b0 in zip(a, b):
+        if not torch.equal(a0, b0):
+            return False
+    return True
+
+if __name__ == "__main__":
+    if test():
+        exit(0)
+    else:
+        exit(1)
diff --git a/tools/pnnx/tests/onnx/test_nn_RNN.py b/tools/pnnx/tests/onnx/test_nn_RNN.py
new file mode 100644
index
00000000000..059cadd8965
--- /dev/null
+++ b/tools/pnnx/tests/onnx/test_nn_RNN.py
@@ -0,0 +1,83 @@
+# Tencent is pleased to support the open source community by making ncnn available.
+#
+# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software distributed
+# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+# CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    """Chains nn.RNN stacks (tanh / relu): seq-first on x, batch_first on y, uni- and bidirectional."""
+    def __init__(self):
+        super(Model, self).__init__()
+
+        self.rnn_0_0 = nn.RNN(input_size=32, hidden_size=16)
+        self.rnn_0_1 = nn.RNN(input_size=16, hidden_size=16, num_layers=3, nonlinearity='tanh', bias=False)
+        self.rnn_0_2 = nn.RNN(input_size=16, hidden_size=16, num_layers=4, nonlinearity='relu', bias=True, bidirectional=True)
+        self.rnn_0_3 = nn.RNN(input_size=16, hidden_size=16, num_layers=4, nonlinearity='tanh', bias=True, bidirectional=True)
+
+        self.rnn_1_0 = nn.RNN(input_size=25, hidden_size=16, batch_first=True)
+        self.rnn_1_1 = nn.RNN(input_size=16, hidden_size=16, num_layers=3, nonlinearity='tanh', bias=False, batch_first=True)
+        self.rnn_1_2 = nn.RNN(input_size=16, hidden_size=16, num_layers=4, nonlinearity='relu', bias=True, batch_first=True, bidirectional=True)
+        self.rnn_1_3 = nn.RNN(input_size=16, hidden_size=16, num_layers=4, nonlinearity='tanh', bias=True, batch_first=True, bidirectional=True)
+
+    def forward(self, x, y):
+        x0, h0 = self.rnn_0_0(x)
+        x1, _ = self.rnn_0_1(x0)
+        x2, h2 = self.rnn_0_2(x1)
+        x3, h3 = self.rnn_0_3(x1, h2)
+
+        y0, h4 = self.rnn_1_0(y)
+        y1, _ = self.rnn_1_1(y0)
+        y2, h6 = self.rnn_1_2(y1)
+        y3, h7 = self.rnn_1_3(y1, h6)
+        return x2, x3, h0, h2, h3, y2, y3, h4, h6, h7
+
+def test():
+    """Export Model to ONNX, convert it with pnnx and compare against the PyTorch outputs."""
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)
+    x = torch.rand(10, 1, 32)
+    y = torch.rand(1, 12, 25)
+
+    a = net(x, y)
+
+    # export onnx
+    torch.onnx.export(net, (x, y), "test_nn_RNN.onnx")
+
+    # onnx to pnnx
+    import os
+    # a failed conversion must fail the test instead of importing a stale generated module
+    if os.system("../../src/pnnx test_nn_RNN.onnx inputshape=[10,1,32],[1,12,25]") != 0:
+        return False
+
+    # pnnx inference
+    import test_nn_RNN_pnnx
+    b = test_nn_RNN_pnnx.test_inference()
+
+    # zip() stops at the shorter sequence, so a missing output would otherwise go unnoticed
+    if len(a) != len(b):
+        return False
+
+    for a0, b0 in zip(a, b):
+        if not torch.equal(a0, b0):
+            return False
+    return True
+
+if __name__ == "__main__":
+    if test():
+        exit(0)
+    else:
+        exit(1)