diff --git a/.ci/pnnx.yml b/.ci/pnnx.yml index 082dc8fac2f..e951d7e62c5 100644 --- a/.ci/pnnx.yml +++ b/.ci/pnnx.yml @@ -122,6 +122,7 @@ jobs: - name: test run: | export PYTHONUSERBASE=${{ci.workspace}}/torch-${{matrix.torch-version}} + export LD_LIBRARY_PATH=${{ci.workspace}}/torchvision-${{matrix.torchvision-version}}-install/lib export OMP_NUM_THREADS=1 export MKL_NUM_THREADS=1 export MKL_ENABLE_INSTRUCTIONS=SSE4_2 @@ -131,6 +132,7 @@ jobs: - name: python-pnnx run: | export PYTHONUSERBASE=${{ci.workspace}}/torch-${{matrix.torch-version}} + export LD_LIBRARY_PATH=${{ci.workspace}}/torchvision-${{matrix.torchvision-version}}-install/lib export PNNX_WHEEL_WITHOUT_BUILD=ON cd tools/pnnx/python cp ../build/src/pnnx pnnx/ diff --git a/tools/pnnx/CMakeLists.txt b/tools/pnnx/CMakeLists.txt index 377beb91010..7ef60bdca33 100644 --- a/tools/pnnx/CMakeLists.txt +++ b/tools/pnnx/CMakeLists.txt @@ -26,21 +26,17 @@ include(PNNXPyTorch) set(CMAKE_CXX_STANDARD 14) # set(CMAKE_BUILD_TYPE debug) -#set(CMAKE_BUILD_TYPE relwithdebinfo) +# set(CMAKE_BUILD_TYPE relwithdebinfo) # set(CMAKE_BUILD_TYPE release) option(PNNX_COVERAGE "build for coverage" OFF) -#set(Torch_INSTALL_DIR "/home/nihui/.local/lib/python3.9/site-packages/torch" CACHE STRING "") -#set(Torch_INSTALL_DIR "/home/nihui/osd/pnnx/pytorch-v1.10.0/build/install" CACHE STRING "") # set(Torch_INSTALL_DIR "/home/nihui/osd/pnnx/install" CACHE STRING "") -set(TorchVision_INSTALL_DIR "/home/nihui/osd/vision/build/install" CACHE STRING "") +# set(TorchVision_INSTALL_DIR "/home/nihui/osd/pnnx/install" CACHE STRING "") #set(Torch_DIR "${Torch_INSTALL_DIR}/share/cmake/Torch") set(TorchVision_DIR "${TorchVision_INSTALL_DIR}/share/cmake/TorchVision") -# test if libtorch and protobuf has the same cxxabi version - find_package(Python3 COMPONENTS Interpreter Development) PNNXProbeForPyTorchInstall() @@ -67,11 +63,11 @@ find_library(TORCHVISION_LIBRARY torchvision PATHS "${TorchVision_INSTALL_DIR}/l if(TORCHVISION_LIBRARY) message(STATUS "Found TorchVision: ${TORCHVISION_LIBRARY}") if(APPLE) - list(APPEND TORCHVISION_LIBRARY "-Wl,-force_load,${TORCHVISION_LIBRARY}") + set(TORCHVISION_LIBRARY "-Wl,-force_load,${TORCHVISION_LIBRARY}") elseif(MSVC) - list(APPEND TORCHVISION_LIBRARY "-WHOLEARCHIVE:${TORCHVISION_LIBRARY}") + set(TORCHVISION_LIBRARY "-WHOLEARCHIVE:${TORCHVISION_LIBRARY}") else() - list(APPEND TORCHVISION_LIBRARY "-Wl,--whole-archive ${TORCHVISION_LIBRARY} -Wl,--no-whole-archive") + set(TORCHVISION_LIBRARY "-Wl,--whole-archive ${TORCHVISION_LIBRARY} -Wl,--no-whole-archive") endif() set(TorchVision_FOUND TRUE) message(STATUS "Building with TorchVision") @@ -84,6 +80,7 @@ endif() include_directories(${TORCH_INCLUDE_DIRS}) if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + # test if libtorch and protobuf has the same cxxabi version include(CheckCXXSourceCompiles) set(CMAKE_REQUIRED_FLAGS "${TORCH_CXX_FLAGS}") check_cxx_source_compiles("#include \n#if _GLIBCXX_USE_CXX11_ABI\nint main() { return 0; }\n#endif" PNNX_TORCH_USE_CXX11_ABI) diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 4b97110ba00..cd27ceb2ea8 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -621,20 +621,36 @@ endif() # set(pnnx_pass_onnx_SRCS # pass_onnx/canonicalize.cpp # pass_onnx/dead_code_elimination.cpp +# pass_onnx/eliminate_noop.cpp # pass_onnx/fold_constants.cpp # pass_onnx/inline_containers.cpp # pass_onnx/model_stat.cpp # pass_onnx/shape_inference.cpp +# +# pass_onnx/nn_AdaptiveAvgPool2d.cpp +# 
pass_onnx/nn_AdaptiveAvgPool3d.cpp +# pass_onnx/nn_AvgPool2d.cpp +# pass_onnx/nn_AvgPool3d.cpp +# pass_onnx/nn_BatchNorm2d.cpp +# pass_onnx/nn_BatchNorm3d.cpp +# pass_onnx/nn_Conv2d.cpp +# pass_onnx/nn_Conv3d.cpp +# pass_onnx/nn_GELU.cpp +# pass_onnx/nn_LayerNorm.cpp +# pass_onnx/nn_Linear.cpp +# pass_onnx/nn_MaxPool2d.cpp +# pass_onnx/nn_MaxPool3d.cpp +# pass_onnx/nn_MultiheadAttention.cpp # ) # # set(onnx2pnnx_SRCS +# pass_onnx.cpp # ${pnnx_pass_onnx_SRCS} # load_onnx.cpp # ) # -# add_library(onnx2pnnx STATIC ${onnx2pnnx_SRCS}) +# add_library(onnx2pnnx OBJECT ${onnx2pnnx_SRCS}) # target_link_libraries(onnx2pnnx PRIVATE onnxproto onnxruntime::onnxruntime) -# # target_compile_definitions(onnx2pnnx PRIVATE BUILD_ONNX2PNNX) # # message(STATUS "Building with dynamo-onnx") diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 47bf293a495..68de54ef8b2 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -943,7 +943,7 @@ static std::string sanitize_identifier(const std::string& s) std::string ss = s; for (size_t i = 0; i < ss.size(); i++) { - if (ss[i] == '.' || ss[i] == ':') + if (ss[i] == '.' || ss[i] == ':' || ss[i] == '/') ss[i] = '_'; } @@ -2771,7 +2771,8 @@ int Graph::parse(const std::string& param) void Operand::remove_consumer(const Operator* c) { auto it = std::find(consumers.begin(), consumers.end(), c); - consumers.erase(it); + if (it != consumers.end()) + consumers.erase(it); } Operator* Graph::new_operator(const std::string& type, const std::string& name) diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h index 84bb46e769a..58e5cd638c2 100644 --- a/tools/pnnx/src/ir.h +++ b/tools/pnnx/src/ir.h @@ -38,9 +38,15 @@ class Tensor; #if BUILD_ONNX2PNNX namespace onnx { +class AttributeProto; class TensorProto; class ValueInfoProto; } // namespace onnx +namespace pnnx { +namespace onnx2pnnx { +class OnnxAttributeProxy; +} // namespace onnx2pnnx +} // namespace pnnx #endif // BUILD_ONNX2PNNX namespace pnnx { @@ -187,6 +193,10 @@ class Parameter Parameter(const torch::jit::Node* value_node); Parameter(const torch::jit::Value* value); #endif // BUILD_TORCH2PNNX +#if BUILD_ONNX2PNNX + Parameter(const onnx::AttributeProto& attr); + Parameter(const onnx2pnnx::OnnxAttributeProxy& attr); +#endif // BUILD_ONNX2PNNX static Parameter parse_from_string(const std::string& value); static std::string encode_to_string(const Parameter& param); @@ -325,6 +335,7 @@ class Graph #endif #if BUILD_ONNX2PNNX Operand* new_operand(const onnx::ValueInfoProto& value); + Operand* new_operand(const onnx::TensorProto& t); #endif Operand* new_operand(const std::string& name); diff --git a/tools/pnnx/src/load_onnx.cpp b/tools/pnnx/src/load_onnx.cpp new file mode 100644 index 00000000000..1ea1aa4973a --- /dev/null +++ b/tools/pnnx/src/load_onnx.cpp @@ -0,0 +1,473 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "load_onnx.h" + +#include "onnx.pb.h" + +#include +#include +#include +#include + +#include + +#include "ir.h" + +#include "pass_onnx/canonicalize.h" +#include "pass_onnx/dead_code_elimination.h" +#include "pass_onnx/eliminate_noop.h" +#include "pass_onnx/fold_constants.h" +#include "pass_onnx/inline_containers.h" +#include "pass_onnx/model_stat.h" +#include "pass_onnx/shape_inference.h" + +#include "pass_onnx.h" + +namespace pnnx { + +static size_t type_to_elemsize(int type) +{ + if (type == 1) return 4; + if (type == 2) return 8; + if (type == 3) return 2; + if (type == 4) return 4; + if (type == 5) return 8; + if (type == 6) return 2; + if (type == 7) return 1; + if (type == 8) return 1; + if (type == 9) return 1; + if (type == 10) return 8; + if (type == 11) return 16; + if (type == 12) return 4; + return 0; // null +} + +static int get_onnx_tensor_type(int32_t dt) +{ + if (dt == onnx::TensorProto::FLOAT) return 1; + if (dt == onnx::TensorProto::DOUBLE) return 2; + if (dt == onnx::TensorProto::FLOAT16) return 3; + if (dt == onnx::TensorProto::INT32) return 4; + if (dt == onnx::TensorProto::INT64) return 5; + if (dt == onnx::TensorProto::INT16) return 6; + if (dt == onnx::TensorProto::INT8) return 7; + if (dt == onnx::TensorProto::UINT8) return 8; + if (dt == onnx::TensorProto::BOOL) return 9; + if (dt == onnx::TensorProto::COMPLEX64) return 10; + if (dt == onnx::TensorProto::COMPLEX128) return 11; + return 0; // unknown type +} + +Parameter::Parameter(const onnx::AttributeProto& attr) +{ + type = 0; + + switch (attr.type()) + { + case onnx::AttributeProto::INT: + { + type = 2; + int64_t i64 = attr.i(); + if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + i = (int)i64; + break; + } + case onnx::AttributeProto::FLOAT: + { + type = 3; + f = attr.f(); + break; + } + case onnx::AttributeProto::STRING: + { + type = 4; + s = attr.s(); + break; + } + case onnx::AttributeProto::INTS: + { + type = 5; + for (int i = 0; i < attr.ints().size(); i++) + { + int64_t i64 = attr.ints().at(i); + if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + ai.push_back(i64); + } + break; + } + case onnx::AttributeProto::FLOATS: + { + type = 6; + for (int i = 0; i < attr.floats().size(); i++) + { + float f = attr.floats().at(i); + af.push_back(f); + } + break; + } + case onnx::AttributeProto::STRINGS: + { + type = 7; + for (int i = 0; i < attr.strings().size(); i++) + { + std::string s = attr.strings().at(i); + as.push_back(s); + } + break; + } + case onnx::AttributeProto::TENSOR: + { + const onnx::TensorProto& tensor = attr.t(); + + int64_t numel = 1; + for (int k = 0; k < tensor.dims_size(); k++) + { + numel *= tensor.dims(k); + } + + if (numel == 1) + { + if (tensor.data_type() == onnx::TensorProto::INT32) + { + type = 2; + if (tensor.has_raw_data()) + { + // assert tensor.raw_data().size() == 4 + i = ((int*)tensor.raw_data().data())[0]; + } + else + { + // assert tensor.int32_data().size() == 1 + i = tensor.int32_data().at(0); + } + } + else if (tensor.data_type() == onnx::TensorProto::INT64) + { + type = 2; + int64_t i64; + if (tensor.has_raw_data()) + { + // assert tensor.raw_data().size() == 8 + i64 = ((int64_t*)tensor.raw_data().data())[0]; + } + else + { + // assert tensor.int64_data().size() == 1 + i64 = tensor.int64_data().at(0); + } + if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + i = (int)i64; + } + else if 
(tensor.data_type() == onnx::TensorProto::FLOAT) + { + type = 3; + if (tensor.has_raw_data()) + { + // assert tensor.raw_data().size() == 4 + f = ((float*)tensor.raw_data().data())[0]; + } + else + { + // assert tensor.float_data().size() == 1 + f = tensor.float_data().at(0); + } + } + else + { + fprintf(stderr, "unknown Node attribute tensor data type %d\n", (int)tensor.data_type()); + } + } + else + { + // constant tensor will become pnnx attribute node later + type = 8; + } + break; + } + default: + { + fprintf(stderr, "unknown Node attribute type %d\n", (int)attr.type()); + break; + } + } +} + +Parameter::Parameter(const onnx2pnnx::OnnxAttributeProxy& attr) + : Parameter(attr.attr) +{ +} + +Attribute::Attribute(const onnx::TensorProto& t) +{ + type = get_onnx_tensor_type(t.data_type()); + + const int ndim = (int)t.dims_size(); + + if (ndim == 0) + { + shape = {1}; + + data.resize(type_to_elemsize(type)); + + if (t.has_raw_data()) + { + // assert t.raw_data().size() == type_to_elemsize(type) + memcpy((void*)data.data(), (const void*)t.raw_data().data(), t.raw_data().size()); + } + else if (t.data_type() == onnx::TensorProto::INT64) + { + int64_t i = t.int64_data().at(0); + memcpy((void*)data.data(), (const void*)&i, data.size()); + } + else if (t.data_type() == onnx::TensorProto::INT32) + { + int i = t.int32_data().at(0); + memcpy((void*)data.data(), (const void*)&i, data.size()); + } + else if (t.data_type() == onnx::TensorProto::DOUBLE) + { + double f = t.double_data().at(0); + memcpy((void*)data.data(), (const void*)&f, data.size()); + } + else if (t.data_type() == onnx::TensorProto::FLOAT) + { + float f = t.float_data().at(0); + memcpy((void*)data.data(), (const void*)&f, data.size()); + } + else + { + fprintf(stderr, "unknown Attribute tensor scalar type %d\n", type); + } + + return; + } + + shape.resize(ndim); + for (int i = 0; i < ndim; i++) + shape[i] = t.dims(i); + + if (shape.size() > 0) + { + data.resize(elemcount() * type_to_elemsize(type)); + + if (t.has_raw_data()) + { + memcpy((void*)data.data(), (const void*)t.raw_data().data(), data.size()); + } + else if (t.data_type() == onnx::TensorProto::INT64) + { + memcpy((void*)data.data(), (const void*)t.int64_data().data(), data.size()); + } + else if (t.data_type() == onnx::TensorProto::INT32) + { + memcpy((void*)data.data(), (const void*)t.int32_data().data(), data.size()); + } + else if (t.data_type() == onnx::TensorProto::DOUBLE) + { + memcpy((void*)data.data(), (const void*)t.double_data().data(), data.size()); + } + else if (t.data_type() == onnx::TensorProto::FLOAT) + { + memcpy((void*)data.data(), (const void*)t.float_data().data(), data.size()); + } + else + { + fprintf(stderr, "unknown Attribute tensor scalar type %d\n", type); + } + } +} + +Operand* Graph::new_operand(const onnx::ValueInfoProto& value) +{ + Operand* r = new Operand; + r->name = value.name(); + + int32_t et = value.type().tensor_type().elem_type(); + r->type = get_onnx_tensor_type(et); + + const onnx::TensorShapeProto& tensor_shape = value.type().tensor_type().shape(); + r->shape.resize(tensor_shape.dim_size()); + for (int z = 0; z < tensor_shape.dim_size(); z++) + { + r->shape[z] = tensor_shape.dim(z).dim_value(); + } + + operands.push_back(r); + return r; +} + +Operand* Graph::new_operand(const onnx::TensorProto& t) +{ + Operand* r = new Operand; + r->name = t.name(); + + r->type = get_onnx_tensor_type(t.data_type()); + + const int ndim = (int)t.dims_size(); + if (ndim == 0) + { + r->shape = {1}; + } + else + { + r->shape.resize(ndim); + for (int i 
= 0; i < ndim; i++) + r->shape[i] = t.dims(i); + } + + operands.push_back(r); + return r; +} + +static bool read_proto_from_binary(const char* filepath, onnx::ModelProto* message) +{ + std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary); + if (!fs.is_open()) + { + fprintf(stderr, "open failed %s\n", filepath); + return false; + } + + google::protobuf::io::IstreamInputStream input(&fs); + google::protobuf::io::CodedInputStream codedstr(&input); + +#if GOOGLE_PROTOBUF_VERSION >= 3011000 + codedstr.SetTotalBytesLimit(INT_MAX); +#else + codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2); +#endif + + bool success = message->ParseFromCodedStream(&codedstr); + + fs.close(); + + return success; +} + +static double get_current_time() +{ + auto now = std::chrono::high_resolution_clock::now(); + auto usec = std::chrono::duration_cast(now.time_since_epoch()); + return usec.count() / 1000.0; +} + +int load_onnx(const std::string& onnxpath, Graph& pnnx_graph) +{ + onnx::ModelProto model; + + bool s1 = read_proto_from_binary(onnxpath.c_str(), &model); + if (!s1) + { + fprintf(stderr, "read_proto_from_binary failed\n"); + return -1; + } + + fprintf(stderr, "############# pass_level0 onnx \n"); + + onnx2pnnx::ModelStat oldstat = onnx2pnnx::get_model_stat(model); + + fprintf(stderr, "%-30s", "inline_containers ... "); + + double t0 = get_current_time(); + + onnx2pnnx::inline_containers(model); + + double t1 = get_current_time(); + + fprintf(stderr, "%10.2fms\n", t1 - t0); + + fprintf(stderr, "%-30s", "eliminate_noop ... "); + + t0 = get_current_time(); + + onnx2pnnx::eliminate_noop(model); + + t1 = get_current_time(); + + fprintf(stderr, "%10.2fms\n", t1 - t0); + + fprintf(stderr, "%-30s", "dead_code_elimination ... "); + + t0 = get_current_time(); + + onnx2pnnx::dead_code_elimination(model); + + t1 = get_current_time(); + + fprintf(stderr, "%10.2fms\n", t1 - t0); + + fprintf(stderr, "%-30s", "fold_constants ... "); + + t0 = get_current_time(); + + onnx2pnnx::fold_constants(model); + + t1 = get_current_time(); + + fprintf(stderr, "%10.2fms\n", t1 - t0); + + fprintf(stderr, "%-30s", "dead_code_elimination ... "); + + t0 = get_current_time(); + + onnx2pnnx::dead_code_elimination(model); + + t1 = get_current_time(); + + fprintf(stderr, "%10.2fms\n", t1 - t0); + + fprintf(stderr, "%-30s", "canonicalize ... "); + + t0 = get_current_time(); + + onnx2pnnx::canonicalize(model); + + t1 = get_current_time(); + + fprintf(stderr, "%10.2fms\n", t1 - t0); + + fprintf(stderr, "%-30s", "shape_inference ... "); + + t0 = get_current_time(); + + onnx2pnnx::shape_inference(model); + + t1 = get_current_time(); + + fprintf(stderr, "%10.2fms\n", t1 - t0); + + // save + std::fstream output("tmp2.onnx", std::ios::out | std::ios::trunc | std::ios::binary); + if (!model.SerializeToOstream(&output)) + { + fprintf(stderr, "write onnx failed\n"); + return -1; + } + + onnx2pnnx::ModelStat newstat = onnx2pnnx::get_model_stat(model); + + onnx2pnnx::print_model_stat(oldstat, newstat); + + fprintf(stderr, "############# pass_level1 onnx\n"); + + pass_onnx(model, pnnx_graph); + + return 0; +} + +} // namespace pnnx diff --git a/tools/pnnx/src/load_onnx.h b/tools/pnnx/src/load_onnx.h new file mode 100644 index 00000000000..e1b99757e86 --- /dev/null +++ b/tools/pnnx/src/load_onnx.h @@ -0,0 +1,26 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. 
+// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef PNNX_LOAD_ONNX_H +#define PNNX_LOAD_ONNX_H + +#include "ir.h" + +namespace pnnx { + +int load_onnx(const std::string& onnxpath, Graph& g); + +} // namespace pnnx + +#endif // PNNX_LOAD_ONNX_H diff --git a/tools/pnnx/src/pass_level2/F_avg_pool2d.cpp b/tools/pnnx/src/pass_level2/F_avg_pool2d.cpp index 6c92b31808a..82b73c9c650 100644 --- a/tools/pnnx/src/pass_level2/F_avg_pool2d.cpp +++ b/tools/pnnx/src/pass_level2/F_avg_pool2d.cpp @@ -43,4 +43,46 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_avg_pool2d, 10) +class F_avg_pool2d_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input_0 0 1 input +AveragePool op_0 1 1 input out kernel_shape=%kernel_shape strides=%strides pads=%pads ceil_mode=%ceil_mode count_include_pad=%count_include_pad auto_pad=* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.avg_pool2d"; + } + + void write(Operator* op, const std::map& captured_params) const + { + std::vector kernel_shape = captured_params.at("kernel_shape").ai; + std::vector strides = captured_params.at("strides").ai; + std::vector pads = captured_params.at("pads").ai; + int ceil_mode = captured_params.at("ceil_mode").i; + int count_include_pad = captured_params.at("count_include_pad").i; + + if (pads.size() == 4) + { + pads = {pads[0], pads[1]}; + } + + op->params["kernel_size"] = kernel_shape; + op->params["stride"] = strides; + op->params["padding"] = pads; + op->params["ceil_mode"] = (ceil_mode != 0); + op->params["count_include_pad"] = (count_include_pad != 0); + op->params["divisor_override"] = Parameter(); + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_avg_pool2d_1, 10) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_conv3d.cpp b/tools/pnnx/src/pass_level2/F_conv3d.cpp index 6f2341ba58c..742417081f4 100644 --- a/tools/pnnx/src/pass_level2/F_conv3d.cpp +++ b/tools/pnnx/src/pass_level2/F_conv3d.cpp @@ -52,4 +52,50 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv3d, 10) +class F_conv3d_0 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +prim::Constant op_0 0 1 transposed value=False +aten::convolution_onnx op_1 4 1 input weight bias transposed out dilations=%dilations groups=%groups output_padding=(0,0,0) pads=%pads strides=%strides +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.conv3d"; + } + + bool match(const std::map& captured_params) const + { + const std::vector& dilations = captured_params.at("dilations").ai; + const std::vector& strides = captured_params.at("strides").ai; + const std::vector& pads = captured_params.at("pads").ai; + return dilations.size() == 3 && strides.size() == 3 && pads.size() == 6; + } + + void write(Operator* op, const 
std::map& captured_params) const + { + std::vector pads = captured_params.at("pads").ai; + if (pads.size() == 6) + { + pads = {pads[0], pads[1], pads[2]}; + } + + op->params["dilation"] = captured_params.at("dilations"); + op->params["stride"] = captured_params.at("strides"); + op->params["padding"] = pads; + op->params["groups"] = captured_params.at("groups"); + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv3d_0, 10) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_layer_norm.cpp b/tools/pnnx/src/pass_level2/F_layer_norm.cpp index ff914e2ca3f..e577cf97a3a 100644 --- a/tools/pnnx/src/pass_level2/F_layer_norm.cpp +++ b/tools/pnnx/src/pass_level2/F_layer_norm.cpp @@ -42,4 +42,55 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_layer_norm, 10) +class F_layer_norm_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 6 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +LayerNormalization op_0 3 3 input weight bias out Mean InvStdDev axis=%axis epsilon=%epsilon stash_type=%stash_type +pnnx.Output output 3 0 out Mean InvStdDev +)PNNXIR"; + } + + const char* type_str() const + { + return "F.layer_norm"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int input_rank = op->inputs[0]->shape.size(); + + int axis = captured_params.at("axis").i; + if (axis < 0) + { + axis = input_rank + axis; + } + + std::vector normalized_shape; + for (int i = axis; i < input_rank; i++) + { + normalized_shape.push_back(op->inputs[0]->shape[i]); + } + + op->params["normalized_shape"] = normalized_shape; + op->params["eps"] = captured_params.at("epsilon"); + + // drop Mean and InvStdDev if not used + if (op->outputs[1]->consumers.empty() && op->outputs[2]->consumers.empty()) + { + op->outputs[1]->producer = 0; + op->outputs[2]->producer = 0; + op->outputs.resize(1); + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_layer_norm_1, 10) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_max_pool2d.cpp b/tools/pnnx/src/pass_level2/F_max_pool2d.cpp index 4b08f577a8b..4b8e0580c31 100644 --- a/tools/pnnx/src/pass_level2/F_max_pool2d.cpp +++ b/tools/pnnx/src/pass_level2/F_max_pool2d.cpp @@ -80,4 +80,46 @@ pnnx.Output output 2 0 out indices REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_max_pool2d_2, 10) +class F_max_pool2d_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 3 +pnnx.Input input_0 0 1 input +aten::max_pool_with_indices_onnx op_1 1 2 input out indices kernel_size=%kernel_size stride=%stride padding=%padding dilation=%dilation ceil_mode=%ceil_mode n_dims_axes=* n_dims_one=* n_dims_zero=* unbatched_rank=* +pnnx.Output output 2 0 out indices +)PNNXIR"; + } + + const char* type_str() const + { + return "F.max_pool2d"; + } + + void write(Operator* op, const std::map& captured_params) const + { + std::vector kernel_size = captured_params.at("kernel_size").ai; + std::vector dilation = captured_params.at("dilation").ai; + std::vector stride = captured_params.at("stride").ai; + std::vector padding = captured_params.at("padding").ai; + int ceil_mode = captured_params.at("ceil_mode").i; + + if (padding.size() == 4) + { + padding = {padding[0], padding[1]}; + } + + op->params["kernel_size"] = kernel_size; + op->params["dilation"] = dilation; + op->params["stride"] = stride; + op->params["padding"] = padding; + op->params["ceil_mode"] = (ceil_mode != 0); + 
op->params["return_indices"] = (op->outputs.size() != 1); + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_max_pool2d_3, 10) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_pad.cpp b/tools/pnnx/src/pass_level2/F_pad.cpp index 398bf43f81b..c4007d25f9b 100644 --- a/tools/pnnx/src/pass_level2/F_pad.cpp +++ b/tools/pnnx/src/pass_level2/F_pad.cpp @@ -44,6 +44,33 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pad, 10) +class F_pad_01 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 pad +aten::constant_pad_nd op_0 2 1 input pad out value=%value +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.pad"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["mode"] = "constant"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pad_01, 10) + class F_pad_1 : public GraphRewriterPass { public: diff --git a/tools/pnnx/src/pass_level2/F_softmax.cpp b/tools/pnnx/src/pass_level2/F_softmax.cpp index 8a9352beba2..8d8068a43c8 100644 --- a/tools/pnnx/src/pass_level2/F_softmax.cpp +++ b/tools/pnnx/src/pass_level2/F_softmax.cpp @@ -39,4 +39,25 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softmax, 10) +class F_softmax_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input_0 0 1 input +aten::softmax_no_dtype op_0 1 1 input out dim=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.softmax"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softmax_1, 10) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/Tensor_copy.cpp b/tools/pnnx/src/pass_level2/Tensor_copy.cpp index d5369b29e8a..1baa14b6ce2 100644 --- a/tools/pnnx/src/pass_level2/Tensor_copy.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_copy.cpp @@ -39,6 +39,28 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_copy, 20) +class Tensor_copy_01 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 self +pnnx.Input input_1 0 1 src +aten::copy op_1 2 1 self src out non_blocking=* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.copy"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_copy_01, 20) + class Tensor_copy_1 : public GraphRewriterPass { public: diff --git a/tools/pnnx/src/pass_level2/Tensor_expand.cpp b/tools/pnnx/src/pass_level2/Tensor_expand.cpp index 9d860c8319f..23c1af6a863 100644 --- a/tools/pnnx/src/pass_level2/Tensor_expand.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_expand.cpp @@ -39,4 +39,26 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_expand, 20) +class Tensor_expand_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 shape +aten::expand op_1 2 1 input shape out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.expand"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_expand_1, 20) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/Tensor_select.cpp b/tools/pnnx/src/pass_level2/Tensor_select.cpp index 3ab8a147bb0..07760fcfc99 100644 --- 
a/tools/pnnx/src/pass_level2/Tensor_select.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_select.cpp @@ -39,4 +39,25 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_select, 20) +class Tensor_select_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input_0 0 1 input +aten::select op_0 1 1 input out dim=%dim index=%index +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.select"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_select_1, 20) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_addmm.cpp b/tools/pnnx/src/pass_level2/torch_addmm.cpp index c8e14a713b0..b15402462bc 100644 --- a/tools/pnnx/src/pass_level2/torch_addmm.cpp +++ b/tools/pnnx/src/pass_level2/torch_addmm.cpp @@ -41,4 +41,27 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_addmm, 20) +class torch_addmm_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 mat1 +pnnx.Input input_2 0 1 mat2 +aten::addmm op_0 3 1 input mat1 mat2 out beta=%beta alpha=%alpha +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.addmm"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_addmm_1, 20) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_cat.cpp b/tools/pnnx/src/pass_level2/torch_cat.cpp index b4d3b5e87d6..2dcbfa4e084 100644 --- a/tools/pnnx/src/pass_level2/torch_cat.cpp +++ b/tools/pnnx/src/pass_level2/torch_cat.cpp @@ -38,4 +38,25 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_cat, 20) +class torch_cat_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input_0 0 1 tensors +aten::cat op_0 1 1 tensors out dim=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.cat"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_cat_1, 20) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_clone.cpp b/tools/pnnx/src/pass_level2/torch_clone.cpp index e645c3cfaf0..82a9d82bb5d 100644 --- a/tools/pnnx/src/pass_level2/torch_clone.cpp +++ b/tools/pnnx/src/pass_level2/torch_clone.cpp @@ -48,4 +48,42 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_clone, 20) +class torch_clone_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +aten::clone op_1 1 1 input out memory_format=%memory_format +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.clone"; + } + + void write(Operator* op, const std::map& captured_params) const + { + if (captured_params.at("memory_format").type == 4 && captured_params.at("memory_format").s.empty()) + { + op->params["memory_format"] = "torch.contiguous_format"; + } + else + { + if (captured_params.at("memory_format").i == 0) + op->params["memory_format"] = "torch.contiguous_format"; + if (captured_params.at("memory_format").i == 1) + op->params["memory_format"] = "torch.preserve_format"; + if (captured_params.at("memory_format").i == 2) + op->params["memory_format"] = "torch.channels_last"; + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_clone_1, 20) + } // namespace pnnx diff --git 
a/tools/pnnx/src/pass_level2/torch_mean.cpp b/tools/pnnx/src/pass_level2/torch_mean.cpp index a43682a8fc6..9fbca3d53c9 100644 --- a/tools/pnnx/src/pass_level2/torch_mean.cpp +++ b/tools/pnnx/src/pass_level2/torch_mean.cpp @@ -40,6 +40,43 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean, 20) +class torch_mean_01 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dim +aten::mean_dim op_0 2 1 input dim out keepdim=%keepdim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.mean"; + } + + void write(Operator* op, const std::map& captured_params) const + { + bool keepdim; + if (captured_params.at("keepdim").type == 2) + { + keepdim = captured_params.at("keepdim").i ? true : false; + } + else // if (captured_params.at("keepdim").type == 1) + { + keepdim = captured_params.at("keepdim").b; + } + + op->params["keepdim"] = keepdim; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_01, 20) + class torch_mean_1 : public GraphRewriterPass { public: diff --git a/tools/pnnx/src/pass_level2/torch_permute.cpp b/tools/pnnx/src/pass_level2/torch_permute.cpp index cb17d7591c3..37e4e3c2441 100644 --- a/tools/pnnx/src/pass_level2/torch_permute.cpp +++ b/tools/pnnx/src/pass_level2/torch_permute.cpp @@ -44,4 +44,29 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_permute, 20) +class torch_permute_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input_0 0 1 input +aten::permute op_0 1 1 input out dims=%dims +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { +#if TORCH_VERSION_MAJOR >= 2 || TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 9 + return "torch.permute"; +#else + return "Tensor.permute"; +#endif + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_permute_1, 20) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_roll.cpp b/tools/pnnx/src/pass_level2/torch_roll.cpp index 238e4915bf7..c71f7f9395d 100644 --- a/tools/pnnx/src/pass_level2/torch_roll.cpp +++ b/tools/pnnx/src/pass_level2/torch_roll.cpp @@ -39,4 +39,26 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_roll, 20) +class torch_roll_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 shifts +aten::roll_shift_and_dim_onnx op_0 2 1 input shifts out dim=%dims +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.roll"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_roll_1, 20) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_split.cpp b/tools/pnnx/src/pass_level2/torch_split.cpp index 565fd3bcf99..0a87e3d57b4 100644 --- a/tools/pnnx/src/pass_level2/torch_split.cpp +++ b/tools/pnnx/src/pass_level2/torch_split.cpp @@ -39,6 +39,28 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_split, 20) +class torch_split_01 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 tensor +pnnx.Input input_1 0 1 split_size_or_sections +aten::split op_0 2 1 tensor split_size_or_sections out dim=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + 
return "torch.split"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_split_01, 20) + class torch_split_1 : public GraphRewriterPass { public: diff --git a/tools/pnnx/src/pass_level2/torch_squeeze.cpp b/tools/pnnx/src/pass_level2/torch_squeeze.cpp index 95289b6ff80..4300b3ef63d 100644 --- a/tools/pnnx/src/pass_level2/torch_squeeze.cpp +++ b/tools/pnnx/src/pass_level2/torch_squeeze.cpp @@ -38,6 +38,27 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_squeeze, 20) +class torch_squeeze_01 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input_0 0 1 input +aten::squeeze_dim op_0 1 1 input out dim=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.squeeze"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_squeeze_01, 20) + class torch_squeeze_0 : public GraphRewriterPass { public: diff --git a/tools/pnnx/src/pass_level2/torch_unbind.cpp b/tools/pnnx/src/pass_level2/torch_unbind.cpp index c973b904b93..396f874fc3f 100644 --- a/tools/pnnx/src/pass_level2/torch_unbind.cpp +++ b/tools/pnnx/src/pass_level2/torch_unbind.cpp @@ -38,4 +38,25 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_unbind, 20) +class torch_unbind_0 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input_0 0 1 input +aten::unbind op_0 1 1 input out dim=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.unbind"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_unbind_0, 20) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_unsqueeze.cpp b/tools/pnnx/src/pass_level2/torch_unsqueeze.cpp index 9acffa1d041..c7f2d8ad467 100644 --- a/tools/pnnx/src/pass_level2/torch_unsqueeze.cpp +++ b/tools/pnnx/src/pass_level2/torch_unsqueeze.cpp @@ -38,4 +38,25 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_unsqueeze, 20) +class torch_unsqueeze_01 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input_0 0 1 input +aten::unsqueeze op_0 1 1 input out dim=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.unsqueeze"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_unsqueeze_01, 20) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_GELU.cpp b/tools/pnnx/src/pass_ncnn/nn_GELU.cpp index bec078bbeb2..bce1cc09202 100644 --- a/tools/pnnx/src/pass_ncnn/nn_GELU.cpp +++ b/tools/pnnx/src/pass_ncnn/nn_GELU.cpp @@ -44,6 +44,32 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_GELU, 20) +class nn_GELU_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.GELU op_0 1 1 input out approximate=* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "GELU"; + } + + const char* name_str() const + { + return "gelu"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_GELU_1, 20) + } // namespace ncnn } // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx.cpp b/tools/pnnx/src/pass_onnx.cpp new file mode 100644 index 00000000000..0c2f4d2e318 --- /dev/null +++ b/tools/pnnx/src/pass_onnx.cpp @@ -0,0 +1,923 @@ +// Tencent is pleased to support the open source community by making 
ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_onnx.h" + +#include "onnx.pb.h" + +#include +#include +#include +#include + +#include + +#include "ir.h" + +namespace pnnx { + +namespace onnx2pnnx { + +static float get_tensor_f(const onnx::TensorProto& tensor) +{ + int64_t numel = 1; + for (int k = 0; k < tensor.dims_size(); k++) + { + numel *= tensor.dims(k); + } + + if (numel != 1) + { + fprintf(stderr, "get_tensor_f numel %ld\n", numel); + } + + if (tensor.data_type() == onnx::TensorProto::FLOAT) + { + if (tensor.has_raw_data()) + { + // assert tensor.raw_data().size() == 4 + return ((float*)tensor.raw_data().data())[0]; + } + + // assert tensor.float_data().size() == 1 + return tensor.float_data().at(0); + } + + // fatal error + fprintf(stderr, "get_tensor_f failed\n"); + return 0.f; +} + +static std::vector get_tensor_af(const onnx::TensorProto& tensor) +{ + if (tensor.dims_size() != 1) + { + fprintf(stderr, "get_tensor_af dims_size %d\n", (int)tensor.dims_size()); + } + + const int64_t numel = tensor.dims(0); + + if (tensor.data_type() == onnx::TensorProto::FLOAT) + { + const float* p = tensor.has_raw_data() ? (float*)tensor.raw_data().data() : tensor.float_data().data(); + std::vector af(numel); + memcpy(af.data(), p, sizeof(float) * numel); + return af; + } + + // fatal error + fprintf(stderr, "get_tensor_af failed\n"); + return std::vector(); +} + +static int64_t get_tensor_i(const onnx::TensorProto& tensor) +{ + int64_t numel = 1; + for (int k = 0; k < tensor.dims_size(); k++) + { + numel *= tensor.dims(k); + } + + if (numel != 1) + { + fprintf(stderr, "get_tensor_i numel %ld\n", numel); + } + + if (tensor.data_type() == onnx::TensorProto::INT32) + { + if (tensor.has_raw_data()) + { + // assert tensor.raw_data().size() == 4 + return ((int*)tensor.raw_data().data())[0]; + } + + // assert tensor.int32_data().size() == 1 + return tensor.int32_data().at(0); + } + + if (tensor.data_type() == onnx::TensorProto::INT64) + { + if (tensor.has_raw_data()) + { + // assert tensor.raw_data().size() == 8 + return ((int64_t*)tensor.raw_data().data())[0]; + } + + // assert tensor.int64_data().size() == 1 + return tensor.int64_data().at(0); + } + + // fatal error + fprintf(stderr, "get_tensor_i failed\n"); + return 0; +} + +static std::vector get_tensor_ai(const onnx::TensorProto& tensor) +{ + if (tensor.dims_size() != 1) + { + fprintf(stderr, "get_tensor_af dims_size %d\n", (int)tensor.dims_size()); + } + + const int64_t numel = tensor.dims(0); + + if (tensor.data_type() == onnx::TensorProto::INT32) + { + const int* p = tensor.has_raw_data() ? (int*)tensor.raw_data().data() : tensor.int32_data().data(); + std::vector ai(numel); + for (int i = 0; i < numel; i++) + ai[i] = p[i]; + return ai; + } + + if (tensor.data_type() == onnx::TensorProto::INT64) + { + const int64_t* p = tensor.has_raw_data() ? 
(int64_t*)tensor.raw_data().data() : tensor.int64_data().data(); + std::vector ai(numel); + memcpy(ai.data(), p, sizeof(int64_t) * numel); + return ai; + } + + // fatal error + fprintf(stderr, "get_tensor_ai failed\n"); + return std::vector(); +} + +float OnnxAttributeProxy::value_f() const +{ + if (attr.type() == onnx::AttributeProto::FLOAT) + { + return attr.f(); + } + + if (attr.type() == onnx::AttributeProto::TENSOR) + { + return get_tensor_f(attr.t()); + } + + fprintf(stderr, "OnnxAttributeProxy value_f failed\n"); + return 0.f; +} + +int64_t OnnxAttributeProxy::value_i() const +{ + if (attr.type() == onnx::AttributeProto::INT) + { + return attr.i(); + } + + if (attr.type() == onnx::AttributeProto::TENSOR) + { + return get_tensor_i(attr.t()); + } + + fprintf(stderr, "OnnxAttributeProxy value_i failed\n"); + return 0; +} + +std::string OnnxAttributeProxy::value_s() const +{ + if (attr.type() != onnx::AttributeProto::STRING) + fprintf(stderr, "OnnxAttributeProxy value_s failed\n"); + + return attr.s(); +} + +std::vector OnnxAttributeProxy::value_fs() const +{ + if (attr.type() == onnx::AttributeProto::FLOATS) + { + const int size = attr.floats().size(); + std::vector fs(size); + for (int i = 0; i < size; i++) + { + fs[i] = attr.floats().at(i); + } + return fs; + } + + if (attr.type() == onnx::AttributeProto::TENSOR) + { + return get_tensor_af(attr.t()); + } + + fprintf(stderr, "OnnxAttributeProxy value_fs failed\n"); + return std::vector(); +} + +std::vector OnnxAttributeProxy::value_is() const +{ + if (attr.type() == onnx::AttributeProto::INTS) + { + const int size = attr.ints().size(); + std::vector is(size); + for (int i = 0; i < size; i++) + { + is[i] = attr.ints().at(i); + } + return is; + } + + if (attr.type() == onnx::AttributeProto::TENSOR) + { + return get_tensor_ai(attr.t()); + } + + fprintf(stderr, "OnnxAttributeProxy value_is failed\n"); + return std::vector(); +} + +std::vector OnnxAttributeProxy::value_ss() const +{ + if (attr.type() != onnx::AttributeProto::STRINGS) + fprintf(stderr, "OnnxAttributeProxy value_ss failed\n"); + + const int size = attr.strings().size(); + std::vector ss(size); + for (int i = 0; i < size; i++) + { + ss[i] = attr.strings().at(i); + } + return ss; +} + +OnnxNodeProxy::OnnxNodeProxy(const onnx::NodeProto& _node) + : node(_node) +{ + // extract attribute info + for (int i = 0; i < node.attribute_size(); i++) + { + const std::string& name = node.attribute(i).name(); + attributes.insert(std::make_pair(name, i)); + } +} + +bool OnnxNodeProxy::has_attribute(const std::string& name) const +{ + return attributes.count(name); +} + +const OnnxAttributeProxy OnnxNodeProxy::attribute(const std::string& name) const +{ + int attribute_index = attributes.at(name); + return node.attribute(attribute_index); +} + +OnnxFunctionProxy::OnnxFunctionProxy(const onnx::ModelProto& _model, const onnx::NodeProto& _caller, const onnx::FunctionProto& _function) + : model(_model), caller(_caller), function(_function) +{ + for (int i = 0; i < function.node_size(); i++) + { + const std::string& name = function.node(i).name(); + named_nodes.insert(std::make_pair(name, i)); + + const std::string& type = function.node(i).op_type(); + typed_nodes.insert(std::make_pair(type, i)); + } + + for (int i = 0; i < caller.input_size(); i++) + { + const std::string& function_argument = caller.input(i); + + int initializer_index = -1; + for (int j = 0; j < model.graph().initializer_size(); j++) + { + if (model.graph().initializer(j).name() == function_argument) + { + initializer_index = j; 
+ break; + } + } + + const std::string& function_parameter = function.input(i); + initializers.insert(std::make_pair(function_parameter, initializer_index)); + } +} + +bool OnnxFunctionProxy::has_typed_node(const std::string& type) const +{ + return typed_nodes.count(type); +} + +bool OnnxFunctionProxy::has_named_node(const std::string& name) const +{ + return named_nodes.count(name); +} + +const OnnxNodeProxy OnnxFunctionProxy::typed_node(const std::string& type) const +{ + int node_index = typed_nodes.at(type); + return function.node(node_index); +} + +const OnnxNodeProxy OnnxFunctionProxy::named_node(const std::string& name) const +{ + int node_index = named_nodes.at(name); + return function.node(node_index); +} + +const OnnxNodeProxy OnnxFunctionProxy::find_producer(const std::string& name) const +{ + // find Constant node which produces name + for (int i = 0; i < function.node_size(); i++) + { + const onnx::NodeProto& node = function.node(i); + for (int j = 0; j < node.output_size(); j++) + { + if (node.output(j) == name) + { + return node; + } + } + } + + // should never reach here + return function.node(0); +} + +bool OnnxFunctionProxy::has_initializer(const std::string& name) const +{ + return initializers.count(name); +} + +const onnx::TensorProto& OnnxFunctionProxy::initializer(const std::string& name) const +{ + int initializer_index = initializers.at(name); + return model.graph().initializer(initializer_index); +} + +OnnxModelProxy::OnnxModelProxy(const onnx::ModelProto& _model) + : model(_model) +{ + for (int i = 0; i < model.graph().node_size(); i++) + { + const std::string& name = model.graph().node(i).name(); + nodes.insert(std::make_pair(name, i)); + + for (int j = 0; j < model.functions_size(); j++) + { + const std::string& function_name = model.functions(j).name(); + if (function_name == model.graph().node(i).op_type()) + { + functions.insert(std::make_pair(function_name + name, j)); + } + } + } + + for (int i = 0; i < model.graph().input_size(); i++) + { + const std::string& name = model.graph().input(i).name(); + valueinfos.insert(std::make_pair(name, -1)); + } + for (int i = 0; i < model.graph().output_size(); i++) + { + const std::string& name = model.graph().output(i).name(); + valueinfos.insert(std::make_pair(name, -2)); + } + + for (int i = 0; i < model.graph().value_info_size(); i++) + { + const std::string& name = model.graph().value_info(i).name(); + valueinfos.insert(std::make_pair(name, i)); + } + + for (int i = 0; i < model.graph().initializer_size(); i++) + { + const std::string& name = model.graph().initializer(i).name(); + initializers.insert(std::make_pair(name, i)); + } +} + +bool OnnxModelProxy::has_node(const std::string& name) const +{ + return nodes.count(name); +} + +const OnnxNodeProxy OnnxModelProxy::node(const std::string& name) const +{ + int node_index = nodes.at(name); + return model.graph().node(node_index); +} + +bool OnnxModelProxy::has_function(const std::string& name, const std::string& caller) const +{ + return functions.count(name + caller); +} + +const OnnxFunctionProxy OnnxModelProxy::function(const std::string& name, const std::string& caller) const +{ + int function_index = functions.at(name + caller); + return OnnxFunctionProxy(model, node(caller).node, model.functions(function_index)); +} + +bool OnnxModelProxy::has_valueinfo(const std::string& name) const +{ + return valueinfos.count(name); +} + +const onnx::ValueInfoProto& OnnxModelProxy::valueinfo(const std::string& name) const +{ + int valueinfo_index = valueinfos.at(name); + 
if (valueinfo_index == -1) + { + for (int i = 0; i < model.graph().input_size(); i++) + { + if (model.graph().input(i).name() == name) + return model.graph().input(i); + } + } + if (valueinfo_index == -2) + { + for (int i = 0; i < model.graph().output_size(); i++) + { + if (model.graph().output(i).name() == name) + return model.graph().output(i); + } + } + + return model.graph().value_info(valueinfo_index); +} + +bool OnnxModelProxy::has_initializer(const std::string& name) const +{ + return initializers.count(name); +} + +const onnx::TensorProto& OnnxModelProxy::initializer(const std::string& name) const +{ + int initializer_index = initializers.at(name); + return model.graph().initializer(initializer_index); +} + +FuseFunctionPass::~FuseFunctionPass() +{ +} + +void FuseFunctionPass::write(Operator* /*op*/, const OnnxFunctionProxy& /*function*/) const +{ +} + +static std::vector g_global_pnnx_fuse_function_passes; + +const std::vector& get_global_pnnx_fuse_function_passes() +{ + return g_global_pnnx_fuse_function_passes; +} + +FuseFunctionPassRegister::FuseFunctionPassRegister(const FuseFunctionPass* _pass) + : pass(_pass) +{ + g_global_pnnx_fuse_function_passes.push_back(pass); +} + +FuseFunctionPassRegister::~FuseFunctionPassRegister() +{ + delete pass; +} + +} // namespace onnx2pnnx + +static bool string_starts_with(const std::string& s, const std::string& s2) +{ + return strncmp(s.c_str(), s2.c_str(), s2.size()) == 0; +} + +static void fuse_list_unpack(Graph& graph) +{ + // prim::Constant + aten::getitem ... -> prim::ListUnpack + + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "aten::getitem") + continue; + + Operand* op_in = op->inputs[0]; + + const int item_count = (int)op_in->consumers.size(); + + std::vector getitem_ops(item_count); + + Operator* cur = op; + + bool full_getitem = true; + for (Operator* op2 : op_in->consumers) + { + if (op2->type != "aten::getitem") + { + fprintf(stderr, "unbalanced getitem\n"); + full_getitem = false; + break; + } + + int gi = op2->inputs[1]->producer->params.at("value").i; + getitem_ops[gi] = op2; + + if (std::find(graph.ops.begin(), graph.ops.end(), op2) < std::find(graph.ops.begin(), graph.ops.end(), cur)) + cur = op2; + } + + if (!full_getitem) + continue; + + matched = true; + + // delete all getitem ops and replace with ListUnpack + Operator* op_list_unpack = graph.new_operator_before("prim::ListUnpack", op->name, cur); + + op_list_unpack->inputs.push_back(op_in); + for (auto op2 : getitem_ops) + { + op_in->remove_consumer(op2); + } + op_in->consumers.push_back(op_list_unpack); + + op_list_unpack->outputs.resize(getitem_ops.size()); + for (size_t j = 0; j < getitem_ops.size(); j++) + { + op_list_unpack->outputs[j] = getitem_ops[j]->outputs[0]; + getitem_ops[j]->outputs[0]->producer = op_list_unpack; + } + + for (auto op2 : getitem_ops) + { + op2->inputs[1]->remove_consumer(op2); + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2)); + delete op2; + } + + break; + } + + if (!matched) + break; + } +} + +void pass_onnx(const onnx::ModelProto& model, Graph& pnnx_graph) +{ + onnx2pnnx::OnnxModelProxy modelproxy(model); + + const onnx::GraphProto& graph = model.graph(); + + for (int i = 0; i < graph.input_size(); i++) + { + const std::string& output = graph.input(i).name(); + + Operator* op = pnnx_graph.new_operator("pnnx.Input", output); + + const onnx::ValueInfoProto& value = modelproxy.valueinfo(output); + + Operand* op_out = 
pnnx_graph.new_operand(value); + + op_out->producer = op; + op->outputs.push_back(op_out); + } + + for (int i = 0; i < graph.node_size(); i++) + { + const onnx::NodeProto& node = graph.node(i); + + const std::string& op_type = node.op_type(); + + std::string sim_op_type; + + if (node.domain().empty()) + { + // native onnx op + sim_op_type = op_type; + + if (op_type == "SequenceConstruct") + { + sim_op_type = "prim::ListConstruct"; + } + + if (op_type == "Slice") + { + sim_op_type = "aten::slice"; + } + + if (op_type == "Transpose") + { + sim_op_type = "aten::permute"; + } + } + else if (string_starts_with(op_type, "aten_")) + { + // aten_view + sim_op_type = std::string("aten::") + op_type.substr(5); + } + else if (string_starts_with(op_type, "_aten_")) + { + // _aten_roll_shift_and_dim_onnx + sim_op_type = std::string("aten::") + op_type.substr(6); + } + else if (string_starts_with(op_type, "prims_")) + { + // prims_convert_element_type + sim_op_type = std::string("prim::") + op_type.substr(6); + } + else if (string_starts_with(op_type, "nn_")) + { + // torch_nn_modules_conv_Conv2d _conv1_1 + sim_op_type = op_type; + // nn_Conv2d_i -> nn.Conv2d + sim_op_type[2] = '.'; + if (sim_op_type.find_first_of('_') != std::string::npos) + sim_op_type = sim_op_type.substr(0, sim_op_type.find_first_of('_')); + } + else + { + // custom function + sim_op_type = std::string("custom_op.") + op_type; + } + + // fprintf(fp, "%-24s %-8s", sim_op_type.c_str(), node.name().c_str()); + + Operator* op = pnnx_graph.new_operator(sim_op_type, node.name()); + + // bool is_function = modelproxy.has_function(node.op_type(), node.name()); + + bool is_function_op = string_starts_with(sim_op_type, "nn.") || string_starts_with(sim_op_type, "custom_op."); + + bool is_aten_op = string_starts_with(sim_op_type, "aten::"); + + bool is_prim_op = string_starts_with(sim_op_type, "prim::"); + + for (int j = 0; j < node.input_size(); j++) + { + const std::string& input = node.input(j); + + if (modelproxy.has_initializer(input)) + { + // skip function weight + if (is_function_op) + continue; + + const onnx::TensorProto& tensor = modelproxy.initializer(input); + + int64_t numel = 1; + for (int k = 0; k < tensor.dims_size(); k++) + { + numel *= tensor.dims(k); + } + + if (numel == 1) + { + Operator* op_const = pnnx_graph.new_operator_before("prim::Constant", input, op); + + Operand* op_const_out = pnnx_graph.new_operand(input); + + op_const_out->producer = op_const; + op_const->outputs.push_back(op_const_out); + + if (tensor.data_type() == onnx::TensorProto::INT32) + { + if (tensor.has_raw_data()) + { + // assert tensor.raw_data().size() == 4 + op_const->params["value"] = ((int*)tensor.raw_data().data())[0]; + } + else + { + // assert tensor.int32_data().size() == 1 + op_const->params["value"] = tensor.int32_data().at(0); + } + } + else if (tensor.data_type() == onnx::TensorProto::INT64) + { + int64_t i64; + if (tensor.has_raw_data()) + { + // assert tensor.raw_data().size() == 8 + i64 = ((int64_t*)tensor.raw_data().data())[0]; + } + else + { + // assert tensor.int64_data().size() == 1 + i64 = tensor.int64_data().at(0); + } + if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + op_const->params["value"] = (int)i64; + } + else if (tensor.data_type() == onnx::TensorProto::FLOAT) + { + if (tensor.has_raw_data()) + { + // assert tensor.raw_data().size() == 4 + op_const->params["value"] = ((float*)tensor.raw_data().data())[0]; + } + else + { + // assert tensor.float_data().size() 
== 1 + op_const->params["value"] = tensor.float_data().at(0); + } + } + else if (tensor.data_type() == onnx::TensorProto::BOOL) + { + if (tensor.has_raw_data()) + { + // assert tensor.raw_data().size() == 2 + op_const->params["value"] = ((uint16_t*)tensor.raw_data().data())[0] ? true : false; + } + else + { + // assert tensor.int32_data().size() == 1 + op_const->params["value"] = tensor.int32_data().at(0) ? true : false; + } + } + else + { + fprintf(stderr, "unknown constant scalar type %d\n", (int)tensor.data_type()); + } + } + else if (is_aten_op && tensor.dims_size() == 1 && (tensor.data_type() == onnx::TensorProto::INT32 || tensor.data_type() == onnx::TensorProto::INT64)) + { + // create list expression + Operator* op_const = pnnx_graph.new_operator_before("pnnx.Expression", input, op); + + Operand* op_const_out = pnnx_graph.new_operand(input); + + op_const_out->producer = op_const; + op_const->outputs.push_back(op_const_out); + + const int list_size = tensor.dims(0); + if (tensor.data_type() == onnx::TensorProto::INT32) + { + std::vector ai(list_size); + if (tensor.has_raw_data()) + { + // assert tensor.raw_data().size() == 4 * list_size + memcpy((void*)ai.data(), (int*)tensor.raw_data().data(), sizeof(int) * list_size); + } + else + { + // assert tensor.int32_data().size() == list_size + memcpy((void*)ai.data(), tensor.int32_data().data(), sizeof(int) * list_size); + } + std::string expr = "["; + for (int k = 0; k < (int)ai.size(); k++) + { + expr += std::to_string(ai[k]); + if (k != (int)ai.size() - 1) + expr += ","; + } + expr += "]"; + op_const->params["expr"] = expr; + } + else if (tensor.data_type() == onnx::TensorProto::INT64) + { + std::vector ai(list_size); + if (tensor.has_raw_data()) + { + // assert tensor.raw_data().size() == 8 * list_size + memcpy((void*)ai.data(), (int64_t*)tensor.raw_data().data(), sizeof(int64_t) * list_size); + } + else + { + // assert tensor.int64_data().size() == list_size + memcpy((void*)ai.data(), tensor.int64_data().data(), sizeof(int64_t) * list_size); + } + std::string expr = "["; + for (int k = 0; k < (int)ai.size(); k++) + { + int64_t i64 = ai[k]; + if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + expr += std::to_string(i64); + if (k != (int)ai.size() - 1) + expr += ","; + } + expr += "]"; + op_const->params["expr"] = expr; + } + } + else + { + // create constant for functions + Operator* op_const = pnnx_graph.new_operator_before("pnnx.Attribute", input, op); + + Operand* op_const_out = pnnx_graph.new_operand(tensor); + + op_const_out->producer = op_const; + op_const->outputs.push_back(op_const_out); + + op_const->attrs["data"] = tensor; + } + } + + Operand* op_in = pnnx_graph.get_operand(input); + + op_in->consumers.push_back(op); + op->inputs.push_back(op_in); + } + + for (int j = 0; j < node.output_size(); j++) + { + const std::string& output = node.output(j); + + Operand* op_out = 0; + + if (modelproxy.has_valueinfo(output)) + { + const onnx::ValueInfoProto& value = modelproxy.valueinfo(output); + op_out = pnnx_graph.new_operand(value); + } + else + { + op_out = pnnx_graph.new_operand(output); + } + + op_out->producer = op; + op->outputs.push_back(op_out); + } + + if (is_function_op) + { + const onnx2pnnx::OnnxFunctionProxy function = modelproxy.function(node.op_type(), node.name()); + + for (const auto& ow : onnx2pnnx::get_global_pnnx_fuse_function_passes()) + { + if (sim_op_type != ow->match_type_str()) + continue; + + op->type = ow->type_str(); + ow->write(op, function); + 
+ break; + } + } + else if (is_aten_op) + { + // extract attributes + for (int j = 0; j < node.attribute_size(); j++) + { + const onnx::AttributeProto& attr = node.attribute(j); + + op->params[attr.name()] = attr; + } + + if (op_type == "Slice") + { + // data start end dim step -> data dim start end step + op->inputnames = {"data", "dim", "start", "end", "step"}; + op->inputs = {op->inputs[0], op->inputs[3], op->inputs[1], op->inputs[2], op->inputs[4]}; + } + + if (op_type == "Transpose") + { + op->params["dims"] = op->params["perm"]; + op->params.erase("perm"); + } + } + else if (is_prim_op) + { + // do nothing :) + } + else + { + // onnx native op, extract attributes + for (int j = 0; j < node.attribute_size(); j++) + { + const onnx::AttributeProto& attr = node.attribute(j); + + op->params[attr.name()] = attr; + } + } + } + + for (int i = 0; i < graph.output_size(); i++) + { + const std::string& input = graph.output(i).name(); + + Operator* op = pnnx_graph.new_operator("pnnx.Output", input); + + Operand* op_in = pnnx_graph.get_operand(input); + + op_in->consumers.push_back(op); + op->inputs.push_back(op_in); + } + + // post process + fuse_list_unpack(pnnx_graph); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx.h b/tools/pnnx/src/pass_onnx.h new file mode 100644 index 00000000000..f087dfa4315 --- /dev/null +++ b/tools/pnnx/src/pass_onnx.h @@ -0,0 +1,181 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
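+// Read-only helpers over the onnx protobuf messages used by onnx2pnnx:
+// OnnxAttributeProxy converts an AttributeProto to plain C++ values,
+// OnnxNodeProxy gives name-based access to a node's attributes, and
+// OnnxFunctionProxy / OnnxModelProxy index nodes, functions, value infos
+// and initializers by name. FuseFunctionPass and the
+// REGISTER_GLOBAL_PNNX_FUSE_FUNCTION_PASS registry rewrite matched
+// torch.nn module functions into pnnx operators.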
+ +#ifndef PNNX_PASS_ONNX_H +#define PNNX_PASS_ONNX_H + +#include +#include +#include + +namespace onnx { +class AttributeProto; +class FunctionProto; +class ModelProto; +class NodeProto; +class TensorProto; +class ValueInfoProto; +} // namespace onnx + +namespace pnnx { + +class Operator; +class Graph; + +namespace onnx2pnnx { + +class OnnxAttributeProxy +{ +public: + OnnxAttributeProxy(const onnx::AttributeProto& _attr) + : attr(_attr) + { + } + + operator float() const + { + return value_f(); + } + operator int64_t() const + { + return value_i(); + } + operator std::string() const + { + return value_s(); + } + operator std::vector() const + { + return value_fs(); + } + operator std::vector() const + { + return value_is(); + } + operator std::vector() const + { + return value_ss(); + } + + float value_f() const; + int64_t value_i() const; + std::string value_s() const; + std::vector value_fs() const; + std::vector value_is() const; + std::vector value_ss() const; + +public: + const onnx::AttributeProto& attr; +}; + +class OnnxNodeProxy +{ +public: + OnnxNodeProxy(const onnx::NodeProto& _node); + + bool has_attribute(const std::string& name) const; + const OnnxAttributeProxy attribute(const std::string& name) const; + +public: + const onnx::NodeProto& node; + +protected: + std::unordered_map attributes; +}; + +class OnnxFunctionProxy +{ +public: + OnnxFunctionProxy(const onnx::ModelProto& _model, const onnx::NodeProto& _caller, const onnx::FunctionProto& _function); + + bool has_typed_node(const std::string& type) const; + bool has_named_node(const std::string& name) const; + const OnnxNodeProxy typed_node(const std::string& type) const; + const OnnxNodeProxy named_node(const std::string& name) const; + + const OnnxNodeProxy find_producer(const std::string& name) const; + + bool has_initializer(const std::string& name) const; + const onnx::TensorProto& initializer(const std::string& name) const; + +public: + const onnx::ModelProto& model; + const onnx::NodeProto& caller; + const onnx::FunctionProto& function; + +protected: + std::unordered_map typed_nodes; + std::unordered_map named_nodes; + std::unordered_map initializers; +}; + +class OnnxModelProxy +{ +public: + OnnxModelProxy(const onnx::ModelProto& _model); + + bool has_node(const std::string& name) const; + const OnnxNodeProxy node(const std::string& name) const; + + bool has_function(const std::string& name, const std::string& caller) const; + const OnnxFunctionProxy function(const std::string& name, const std::string& caller) const; + + bool has_valueinfo(const std::string& name) const; + const onnx::ValueInfoProto& valueinfo(const std::string& name) const; + + bool has_initializer(const std::string& name) const; + const onnx::TensorProto& initializer(const std::string& name) const; + +public: + const onnx::ModelProto& model; + +protected: + std::unordered_map nodes; + std::unordered_map functions; + std::unordered_map valueinfos; + std::unordered_map initializers; +}; + +class FuseFunctionPass +{ +public: + virtual ~FuseFunctionPass(); + + virtual const char* match_type_str() const = 0; + + virtual const char* type_str() const = 0; + + virtual void write(Operator* op, const OnnxFunctionProxy& function) const; +}; + +class FuseFunctionPassRegister +{ +public: + FuseFunctionPassRegister(const FuseFunctionPass* pass); + ~FuseFunctionPassRegister(); + const FuseFunctionPass* pass; +}; + +const std::vector& get_global_pnnx_fuse_function_passes(); + +#define REGISTER_GLOBAL_PNNX_FUSE_FUNCTION_PASS(CLASS) \ + static 
FuseFunctionPassRegister g_global_pnnx_fusefunctionpass_##CLASS##_register(new CLASS); + +} // namespace onnx2pnnx + +void pass_onnx(const onnx::ModelProto& model, Graph& pnnx_graph); + +} // namespace pnnx + +#endif // PNNX_PASS_ONNX_H diff --git a/tools/pnnx/src/pass_onnx/canonicalize.cpp b/tools/pnnx/src/pass_onnx/canonicalize.cpp new file mode 100644 index 00000000000..9698c9f73f7 --- /dev/null +++ b/tools/pnnx/src/pass_onnx/canonicalize.cpp @@ -0,0 +1,298 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "canonicalize.h" + +#include +#include +#include + +namespace pnnx { + +namespace onnx2pnnx { + +static bool string_starts_with(const std::string& s, const std::string& s2) +{ + return strncmp(s.c_str(), s2.c_str(), s2.size()) == 0; +} + +void canonicalize(onnx::ModelProto& model) +{ + // collect initializers + std::unordered_set initializers; + { + const onnx::GraphProto& graph = model.graph(); + for (int i = 0; i < graph.initializer_size(); i++) + { + initializers.insert(graph.initializer(i).name()); + } + } + + onnx::GraphProto* graph = model.mutable_graph(); + + std::map function_remap; + + std::map input_output_remap; + int input_output_index = 0; + + // canonicalize graph input output + { + for (int i = 0; i < graph->input_size(); i++) + { + onnx::ValueInfoProto* input = graph->mutable_input(i); + + std::string new_name = std::string("in") + std::to_string(i); + + // fprintf(stderr, "%s -> %s\n", input->name().c_str(), new_name.c_str()); + input_output_remap[input->name()] = new_name; + input->set_name(new_name); + } + for (int i = 0; i < graph->output_size(); i++) + { + onnx::ValueInfoProto* output = graph->mutable_output(i); + + std::string new_name = std::string("out") + std::to_string(i); + + // fprintf(stderr, "%s -> %s\n", output->name().c_str(), new_name.c_str()); + input_output_remap[output->name()] = new_name; + output->set_name(new_name); + } + } + + for (int i = 0; i < graph->node_size(); i++) + { + onnx::NodeProto* node = graph->mutable_node(i); + + // simplify type + { + const std::string& op_type = node->op_type(); + + if (node->domain().empty()) + { + // native onnx op + // Constant + node->set_name(op_type + "_" + std::to_string(i)); + } + else if (string_starts_with(op_type, "aten_")) + { + // aten_view + node->set_name(op_type.substr(5) + "_" + std::to_string(i)); + } + else if (string_starts_with(op_type, "_aten_")) + { + node->set_name(op_type.substr(6) + "_" + std::to_string(i)); + } + else if (string_starts_with(op_type, "prims_")) + { + // prims_convert_element_type + node->set_name(op_type.substr(6) + "_" + std::to_string(i)); + } + else if (string_starts_with(op_type, "torch_nn_modules_") && !string_starts_with(op_type, "torch_nn_modules_container_")) + { + // torch_nn_modules_conv_Conv2d _conv1_1 + // torch_nn_modules_batchnorm_BatchNorm2d _bn1_1 + // 
torch_nn_modules_pooling_MaxPool2d _maxpool_1_3 + // torch_nn_modules_linear_Linear _fc_1 + + if (function_remap.find(op_type) != function_remap.end()) + { + node->set_op_type(function_remap.at(op_type)); + } + else + { + // torch_nn_modules_conv_Conv2d_xyz -> nn_Conv2d_i + char nn_type[256]; + int nconsumed = 0; + sscanf(op_type.c_str() + sizeof("torch_nn_modules_") - 1, "%*[^_]_%255[^_]_%n", nn_type, &nconsumed); + + std::string new_op_type = std::string("nn_") + nn_type + "_" + std::to_string(i); + + function_remap[op_type] = new_op_type; + + node->set_op_type(new_op_type); + } + node->set_name(node->op_type().substr(3)); + } + else + { + // unknown module ? + fprintf(stderr, "unexpected op_type %s\n", op_type.c_str()); + node->set_name(std::string("op_") + std::to_string(i)); + } + } + + // canonicalize name + // node->set_name(std::string("op_") + std::to_string(i)); + + // canonicalize node input output + { + for (int j = 0; j < node->input_size(); j++) + { + const std::string& node_input = node->input(j); + + // some input/output may have empty name, it causes trouble, skip it + if (node_input.empty()) + continue; + + // skip initializer + if (initializers.find(node_input) != initializers.end()) + continue; + + if (input_output_remap.find(node_input) != input_output_remap.end()) + { + node->set_input(j, input_output_remap.at(node_input)); + } + else + { + // fprintf(stderr, "%s -> %s\n", node_input.c_str(), std::to_string(input_output_index).c_str()); + + input_output_remap[node_input] = std::to_string(input_output_index); + node->set_input(j, std::to_string(input_output_index)); + input_output_index++; + } + } + for (int j = 0; j < node->output_size(); j++) + { + const std::string& node_output = node->output(j); + + // some input/output may have empty name, it causes trouble, skip it + if (node_output.empty()) + continue; + + if (input_output_remap.find(node_output) != input_output_remap.end()) + { + node->set_output(j, input_output_remap.at(node_output)); + } + else + { + // fprintf(stderr, "%s -> %s\n", node_output.c_str(), std::to_string(input_output_index).c_str()); + + input_output_remap[node_output] = std::to_string(input_output_index); + node->set_output(j, std::to_string(input_output_index)); + input_output_index++; + } + } + } + } + + // canonicalize all functions + for (int i = 0; i < model.functions_size(); i++) + { + onnx::FunctionProto* function = model.mutable_functions(i); + + if (function_remap.find(function->name()) != function_remap.end()) + { + function->set_name(function_remap.at(function->name())); + } + + if (!string_starts_with(function->name(), "nn_")) + continue; + + // simplify function input + int function_input_index = 0; + int function_output_index = 0; + std::map function_input_output_remap; + for (int j = 0; j < function->input_size(); j++) + { + const std::string& func_input = function->input(j); + + if (initializers.find(func_input) == initializers.end()) + { + // input tensor + std::string new_name = std::string("in") + std::to_string(function_input_index); + function_input_output_remap[func_input] = new_name; + function->set_input(j, new_name); + function_input_index++; + } + else + { + // weights + // layer2.0.bn1.running_mean + size_t last_dot = func_input.find_last_of('.'); + if (last_dot != std::string::npos) + { + std::string new_name = func_input.substr(last_dot + 1); + function_input_output_remap[func_input] = new_name; + function->set_input(j, new_name); + } + } + } + for (int j = 0; j < function->output_size(); j++) + { + const 
std::string& func_output = function->output(j); + + // output tensor + std::string new_name = std::string("out") + std::to_string(function_output_index); + function_input_output_remap[func_output] = new_name; + function->set_output(j, new_name); + function_output_index++; + } + + for (int j = 0; j < function->node_size(); j++) + { + onnx::NodeProto* node = function->mutable_node(j); + + for (int k = 0; k < node->input_size(); k++) + { + const std::string& input = node->input(k); + + if (function_input_output_remap.find(input) != function_input_output_remap.end()) + { + node->set_input(k, function_input_output_remap[input]); + } + } + for (int k = 0; k < node->output_size(); k++) + { + const std::string& output = node->output(k); + + if (function_input_output_remap.find(output) != function_input_output_remap.end()) + { + node->set_output(k, function_input_output_remap[output]); + } + } + } + } + + // canonicalize all initializers + // for (int i = 0; i < graph->initializer_size(); i++) + // { + // onnx::TensorProto* initializer = graph->mutable_initializer(i); + // + // if (input_output_remap.find(initializer->name()) == input_output_remap.end()) + // { + // // skip initializers inside module function + // continue; + // } + // + // initializer->set_name(input_output_remap.at(initializer->name())); + // } + + // canonicalize all values + for (int i = 0; i < graph->value_info_size(); i++) + { + onnx::ValueInfoProto* value = graph->mutable_value_info(i); + + if (input_output_remap.find(value->name()) == input_output_remap.end()) + { + // skip values inside module function + continue; + } + + value->set_name(input_output_remap.at(value->name())); + } +} + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/canonicalize.h b/tools/pnnx/src/pass_onnx/canonicalize.h new file mode 100644 index 00000000000..a24ad86a9fd --- /dev/null +++ b/tools/pnnx/src/pass_onnx/canonicalize.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "onnx.pb.h" + +namespace pnnx { + +namespace onnx2pnnx { + +void canonicalize(onnx::ModelProto& model); + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/dead_code_elimination.cpp b/tools/pnnx/src/pass_onnx/dead_code_elimination.cpp new file mode 100644 index 00000000000..dd54f0f2ffc --- /dev/null +++ b/tools/pnnx/src/pass_onnx/dead_code_elimination.cpp @@ -0,0 +1,265 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "dead_code_elimination.h" + +#include +#include + +namespace pnnx { + +namespace onnx2pnnx { + +void dead_code_elimination(onnx::ModelProto& model) +{ + // collect all nodes that have no links with graph outputs + std::vector dead_outputs; + std::vector dead_node_indexes; + { + const onnx::GraphProto& graph = model.graph(); + + std::unordered_set live_inputs; + for (int i = 0; i < graph.output_size(); i++) + { + live_inputs.insert(graph.output(i).name()); + } + + for (int i = graph.node_size() - 1; i >= 0; i--) + { + const onnx::NodeProto& node = graph.node(i); + + bool is_outputs_live = false; + for (int j = 0; j < node.output_size(); j++) + { + if (live_inputs.find(node.output(j)) != live_inputs.end()) + { + is_outputs_live = true; + break; + } + } + + if (is_outputs_live) + { + for (int j = 0; j < node.output_size(); j++) + { + if (live_inputs.find(node.output(j)) == live_inputs.end()) + { + dead_outputs.push_back(node.output(j)); + } + } + + for (int j = 0; j < node.input_size(); j++) + { + live_inputs.insert(node.input(j)); + } + } + else + { + dead_node_indexes.push_back(i); + } + } + } + + // eliminate dead nodes + { + onnx::GraphProto* graph = model.mutable_graph(); + + for (size_t i = 0; i < dead_node_indexes.size(); i++) + { + const int dead_node_index = dead_node_indexes[i]; + + // ..... dni ....... + const int graph_node_size = graph->node_size(); + for (int j = dead_node_index; j < graph_node_size - 1; j++) + { + graph->mutable_node()->SwapElements(j, j + 1); + } + + // ..... ....... dni + graph->mutable_node()->RemoveLast(); + } + } + + // eliminate dead value info + { + onnx::GraphProto* graph = model.mutable_graph(); + + for (size_t i = 0; i < dead_outputs.size(); i++) + { + const std::string& dead_output = dead_outputs[i]; + + for (int j = 0; j < graph->value_info_size(); j++) + { + if (graph->value_info(j).name() == dead_output) + { + // ..... j ....... + const int graph_value_info_size = graph->value_info_size(); + for (int k = j; k < graph_value_info_size - 1; k++) + { + graph->mutable_node()->SwapElements(k, k + 1); + } + + // ..... ....... 
j + graph->mutable_node()->RemoveLast(); + + break; + } + } + } + } + + // collect all dead functions + std::vector dead_function_indexes; + { + const onnx::GraphProto& graph = model.graph(); + + std::unordered_set live_function_indexes; + for (int i = 0; i < graph.node_size(); i++) + { + const std::string& op_type = graph.node(i).op_type(); + + for (int j = 0; j < model.functions_size(); j++) + { + const onnx::FunctionProto& function = model.functions(j); + + if (function.name() == op_type) + { + live_function_indexes.insert(j); + break; + } + } + } + + // find nested live functions + while (1) + { + bool new_nested_live_function = false; + + for (int i = 0; i < model.functions_size(); i++) + { + if (live_function_indexes.find(i) == live_function_indexes.end()) + continue; + + const onnx::FunctionProto& function = model.functions(i); + + for (int j = 0; j < function.node_size(); j++) + { + const std::string& op_type = function.node(j).op_type(); + + for (int k = 0; k < model.functions_size(); k++) + { + const onnx::FunctionProto& nested_function = model.functions(k); + + if (nested_function.name() == op_type && live_function_indexes.find(k) == live_function_indexes.end()) + { + // nested live function added + live_function_indexes.insert(k); + + new_nested_live_function = true; + } + } + } + } + + if (!new_nested_live_function) + break; + } + + for (int i = model.functions_size() - 1; i >= 0; i--) + { + if (live_function_indexes.find(i) == live_function_indexes.end()) + { + dead_function_indexes.push_back(i); + } + } + } + + // eliminate dead funtions + { + for (size_t i = 0; i < dead_function_indexes.size(); i++) + { + const int dead_function_index = dead_function_indexes[i]; + + // ..... dfi ....... + const int model_functions_size = model.functions_size(); + for (int j = dead_function_index; j < model_functions_size - 1; j++) + { + model.mutable_functions()->SwapElements(j, j + 1); + } + + // ..... ....... dfi + model.mutable_functions()->RemoveLast(); + } + } + + // eliminate dead initializers + { + onnx::GraphProto* graph = model.mutable_graph(); + + std::unordered_set live_inputs; + for (int i = 0; i < graph->node_size(); i++) + { + const onnx::NodeProto& node = graph->node(i); + + for (int j = 0; j < node.input_size(); j++) + { + live_inputs.insert(node.input(j)); + } + } + + // find live inputs in functions + for (int i = 0; i < model.functions_size(); i++) + { + const onnx::FunctionProto& function = model.functions(i); + + for (int j = 0; j < function.node_size(); j++) + { + const onnx::NodeProto& node = function.node(j); + + for (int k = 0; k < node.input_size(); k++) + { + live_inputs.insert(node.input(k)); + } + } + } + + std::vector dead_initializer_indexes; + for (int i = graph->initializer_size() - 1; i >= 0; i--) + { + if (live_inputs.find(graph->initializer(i).name()) == live_inputs.end()) + { + dead_initializer_indexes.push_back(i); + } + } + + for (size_t i = 0; i < dead_initializer_indexes.size(); i++) + { + const int dead_initializer_index = dead_initializer_indexes[i]; + + // ..... dii ....... + const int graph_initializer_size = graph->initializer_size(); + for (int j = dead_initializer_index; j < graph_initializer_size - 1; j++) + { + graph->mutable_initializer()->SwapElements(j, j + 1); + } + + // ..... ....... 
dii + graph->mutable_initializer()->RemoveLast(); + } + } +} + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/dead_code_elimination.h b/tools/pnnx/src/pass_onnx/dead_code_elimination.h new file mode 100644 index 00000000000..b890b6a7d7c --- /dev/null +++ b/tools/pnnx/src/pass_onnx/dead_code_elimination.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "onnx.pb.h" + +namespace pnnx { + +namespace onnx2pnnx { + +void dead_code_elimination(onnx::ModelProto& model); + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/eliminate_noop.cpp b/tools/pnnx/src/pass_onnx/eliminate_noop.cpp new file mode 100644 index 00000000000..cf011f9c29b --- /dev/null +++ b/tools/pnnx/src/pass_onnx/eliminate_noop.cpp @@ -0,0 +1,59 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "eliminate_noop.h" + +#include +#include +#include + +#include + +namespace pnnx { + +namespace onnx2pnnx { + +void eliminate_noop(onnx::ModelProto& model) +{ + onnx::GraphProto* graph = model.mutable_graph(); + + for (int i = 0; i < graph->node_size(); i++) + { + const onnx::NodeProto& node = graph->node(i); + const std::string& op_type = node.op_type(); + + if (op_type == "Identity" || op_type == "aten_copy") + { + const std::string& input_name = node.input(0); + const std::string& output_name = node.output(0); + + for (int j = i + 1; j < graph->node_size(); j++) + { + onnx::NodeProto* node2 = graph->mutable_node(j); + + for (int k = 0; k < node2->input_size(); k++) + { + if (node2->input(k) == output_name) + { + node2->set_input(k, input_name); + } + } + } + } + } +} + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/eliminate_noop.h b/tools/pnnx/src/pass_onnx/eliminate_noop.h new file mode 100644 index 00000000000..74d5781a32a --- /dev/null +++ b/tools/pnnx/src/pass_onnx/eliminate_noop.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. 
+// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "onnx.pb.h" + +namespace pnnx { + +namespace onnx2pnnx { + +void eliminate_noop(onnx::ModelProto& model); + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/fold_constants.cpp b/tools/pnnx/src/pass_onnx/fold_constants.cpp new file mode 100644 index 00000000000..cb5b3a75d95 --- /dev/null +++ b/tools/pnnx/src/pass_onnx/fold_constants.cpp @@ -0,0 +1,435 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fold_constants.h" + +#include +#include +#include + +#include + +namespace pnnx { + +namespace onnx2pnnx { + +static size_t sizeof_onnx_datatype(ONNXTensorElementDataType type) +{ + switch (type) + { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: + return 0; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ: + return 1; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + return 2; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return 4; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: + return 8; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: + return 16; + default: + break; + } + + return 0; +} + +void fold_constants(onnx::ModelProto& model) +{ + // collect initializers + std::unordered_set initializers; + { + const onnx::GraphProto& graph = model.graph(); + for (int i = 0; i < graph.initializer_size(); i++) + { + initializers.insert(graph.initializer(i).name()); + } + } + + // collect all outputs that have no links with graph inputs + std::vector foldable_constants; + { + const onnx::GraphProto& graph = model.graph(); + + std::unordered_set foldable_outputs; + std::unordered_set non_foldable_outputs; + for (int i = 0; i < graph.input_size(); i++) + { + 
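+            // graph inputs are runtime values, so they and anything computed
+            // from them must not be folded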
non_foldable_outputs.insert(graph.input(i).name()); + } + + for (int i = 0; i < graph.node_size(); i++) + { + const onnx::NodeProto& node = graph.node(i); + + const std::string& op_type = node.op_type(); + + bool is_outputs_foldable = true; + for (int j = 0; j < node.input_size(); j++) + { + if (non_foldable_outputs.find(node.input(j)) != non_foldable_outputs.end()) + { + is_outputs_foldable = false; + break; + } + } + + // TODO whitelist for static shape + // aten::size + // aten::_shape_as_tensor + if (op_type == "aten_new_empty" + || op_type == "aten_new_full" + || op_type == "aten_new_ones" + || op_type == "aten_new_zeros" + || op_type == "aten_empty_like" + || op_type == "aten_full_like" + || op_type == "aten_ones_like" + || op_type == "aten_zeros_like") + { + is_outputs_foldable = true; + } + + // TODO whitelist for static shape + if (op_type == "Shape") + { + is_outputs_foldable = true; + } + + // TODO whitelist for static type + if (op_type == "CastLike") + { + is_outputs_foldable = non_foldable_outputs.find(node.input(0)) == non_foldable_outputs.end(); + } + + if (!is_outputs_foldable) + { + for (int j = 0; j < node.input_size(); j++) + { + if (non_foldable_outputs.find(node.input(j)) == non_foldable_outputs.end()) + { + // some input/output may have empty name, it causes trouble, skip it + if (node.input(j).empty()) + continue; + + foldable_outputs.insert(node.input(j)); + } + } + + for (int j = 0; j < node.output_size(); j++) + { + non_foldable_outputs.insert(node.output(j)); + } + } + } + + // skip initializers + for (const std::string& x : foldable_outputs) + { + if (initializers.find(x) == initializers.end()) + { + foldable_constants.push_back(x); + } + } + } + + if (foldable_constants.empty()) + return; + + onnx::GraphProto* graph = model.mutable_graph(); + + // save original outputs + std::vector orig_outputs; + { + for (int i = 0; i < graph->output_size(); i++) + { + orig_outputs.push_back(graph->output(i).name()); + } + } + + // add foldable outputs to onnx output + { + graph->clear_output(); + + for (size_t i = 0; i < foldable_constants.size(); i++) + { + graph->add_output()->set_name(foldable_constants[i]); + } + } + + // generate temp onnx graph + std::string tmp_onnx_data; + { + std::stringstream tmp_onnx_data_ss; + if (!model.SerializeToOstream(&tmp_onnx_data_ss)) + { + fprintf(stderr, "write onnx failed\n"); + return; + } + + tmp_onnx_data = tmp_onnx_data_ss.str(); + } + + // onnxrt inference + { + const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + + OrtStatus* ort_status = 0; + + OrtEnv* ort_env = 0; + ort_status = ort_api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "pnnx", &ort_env); + if (ort_status) + { + fprintf(stderr, "ort CreateEnv failed %s\n", ort_api->GetErrorMessage(ort_status)); + } + + OrtSessionOptions* ort_session_opt = 0; + ort_status = ort_api->CreateSessionOptions(&ort_session_opt); + if (ort_status) + { + fprintf(stderr, "ort CreateSessionOptions failed %s\n", ort_api->GetErrorMessage(ort_status)); + } + + ort_status = ort_api->SetSessionGraphOptimizationLevel(ort_session_opt, ORT_DISABLE_ALL); + if (ort_status) + { + fprintf(stderr, "ort SetSessionGraphOptimizationLevel failed %s\n", ort_api->GetErrorMessage(ort_status)); + } + + // ort_status = ort_api->SetIntraOpNumThreads(ort_session_opt, 4); + // if (ort_status) + // { + // fprintf(stderr, "ort SetIntraOpNumThreads failed %s\n", ort_api->GetErrorMessage(ort_status)); + // } + // + // ort_status = ort_api->SetInterOpNumThreads(ort_session_opt, 4); + // if (ort_status) + // { + // 
fprintf(stderr, "ort SetInterOpNumThreads failed %s\n", ort_api->GetErrorMessage(ort_status)); + // } + + OrtSession* ort_session = 0; + ort_status = ort_api->CreateSessionFromArray(ort_env, (const void*)tmp_onnx_data.data(), tmp_onnx_data.size(), ort_session_opt, &ort_session); + if (ort_status) + { + fprintf(stderr, "ort CreateSession failed %s\n", ort_api->GetErrorMessage(ort_status)); + } + + OrtRunOptions* ort_run_opt = 0; + ort_status = ort_api->CreateRunOptions(&ort_run_opt); + if (ort_status) + { + fprintf(stderr, "ort CreateRunOptions failed %s\n", ort_api->GetErrorMessage(ort_status)); + } + + OrtAllocator* ort_allocator = 0; + ort_status = ort_api->GetAllocatorWithDefaultOptions(&ort_allocator); + if (ort_status) + { + fprintf(stderr, "ort GetAllocatorWithDefaultOptions failed %s\n", ort_api->GetErrorMessage(ort_status)); + } + + std::vector input_names; + std::vector inputs; + for (int i = 0; i < graph->input_size(); i++) + { + const onnx::ValueInfoProto& value = graph->input(i); + + std::vector shape; + const onnx::TensorShapeProto& tsp = value.type().tensor_type().shape(); + for (int k = 0; k < tsp.dim_size(); k++) + { + // TODO has_dim_value ? + shape.push_back(tsp.dim(k).dim_value()); + } + + ONNXTensorElementDataType datatype = (ONNXTensorElementDataType)value.type().tensor_type().elem_type(); + + OrtValue* ort_val = 0; + ort_status = ort_api->CreateTensorAsOrtValue(ort_allocator, (const int64_t*)shape.data(), shape.size(), datatype, &ort_val); + if (ort_status) + { + fprintf(stderr, "ort CreateTensorAsOrtValue failed %s\n", ort_api->GetErrorMessage(ort_status)); + } + + input_names.push_back(value.name().c_str()); + inputs.push_back(ort_val); + } + + std::vector output_names; + std::vector outputs; + for (size_t i = 0; i < foldable_constants.size(); i++) + { + output_names.push_back(foldable_constants[i].c_str()); + outputs.push_back(0); + } + + ort_status = ort_api->Run(ort_session, ort_run_opt, + input_names.data(), inputs.data(), input_names.size(), + output_names.data(), output_names.size(), outputs.data()); + if (ort_status) + { + fprintf(stderr, "ort Run failed %s\n", ort_api->GetErrorMessage(ort_status)); + } + + // TODO get output data + + // graph->clear_output(); + + for (size_t i = 0; i < output_names.size(); i++) + { + OrtTensorTypeAndShapeInfo* info = 0; + ort_status = ort_api->GetTensorTypeAndShape(outputs[i], &info); + if (ort_status) + { + fprintf(stderr, "ort GetTensorTypeAndShape failed %s\n", ort_api->GetErrorMessage(ort_status)); + } + + ONNXTensorElementDataType datatype = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ort_status = ort_api->GetTensorElementType(info, &datatype); + if (ort_status) + { + fprintf(stderr, "ort GetTensorElementType failed %s\n", ort_api->GetErrorMessage(ort_status)); + } + + size_t out_dims = 0; + ort_status = ort_api->GetDimensionsCount(info, &out_dims); + if (ort_status) + { + fprintf(stderr, "ort GetDimensionsCount failed %s\n", ort_api->GetErrorMessage(ort_status)); + } + + // fprintf(stderr, " out_dims = %lu\n", out_dims); + + std::vector out_shape; + out_shape.resize(out_dims); + ort_status = ort_api->GetDimensions(info, out_shape.data(), out_dims); + if (ort_status) + { + fprintf(stderr, "ort GetDimensions failed %s\n", ort_api->GetErrorMessage(ort_status)); + } + + void* tensor_data = 0; + ort_status = ort_api->GetTensorMutableData(outputs[i], &tensor_data); + if (ort_status) + { + fprintf(stderr, "ort GetTensorMutableData failed %s\n", ort_api->GetErrorMessage(ort_status)); + } + + size_t elemcount = 0; + ort_status = 
ort_api->GetTensorShapeElementCount(info, &elemcount); + if (ort_status) + { + fprintf(stderr, "ort GetTensorShapeElementCount failed %s\n", ort_api->GetErrorMessage(ort_status)); + } + + // fprintf(stderr, "%16s = ", output_names[i]); + // for (size_t j = 0; j < out_dims; j++) + // { + // fprintf(stderr, "%lu ", out_shape[j]); + // } + // fprintf(stderr, "\n"); + + // unlink any node that has this output + { + for (int j = 0; j < graph->node_size(); j++) + { + const onnx::NodeProto& node = graph->node(j); + + bool is_producer = false; + int producer_node_output_index = -1; + for (int k = 0; k < node.output_size(); k++) + { + if (node.output(k) == output_names[i]) + { + is_producer = true; + producer_node_output_index = k; + break; + } + } + + if (is_producer) + { + graph->mutable_node(j)->set_output(producer_node_output_index, std::string("pnnx_unlink_") + output_names[i]); + break; + } + } + } + + // create initializer + { + onnx::TensorProto* tp = graph->add_initializer(); + tp->set_name(output_names[i]); + + for (size_t j = 0; j < out_dims; j++) + { + tp->add_dims(out_shape[j]); + } + + tp->set_data_type((int32_t)datatype); + + std::string* data = tp->mutable_raw_data(); + data->resize(sizeof_onnx_datatype(datatype) * elemcount); + memcpy((void*)data->data(), tensor_data, sizeof_onnx_datatype(datatype) * elemcount); + } + + ort_api->ReleaseTensorTypeAndShapeInfo(info); + } + + for (size_t i = 0; i < input_names.size(); i++) + { + ort_api->ReleaseValue(inputs[i]); + } + + for (size_t i = 0; i < output_names.size(); i++) + { + ort_api->ReleaseValue(outputs[i]); + } + + ort_api->ReleaseRunOptions(ort_run_opt); + ort_api->ReleaseSession(ort_session); + ort_api->ReleaseSessionOptions(ort_session_opt); + ort_api->ReleaseEnv(ort_env); + } + + // restore original outputs + { + graph->clear_output(); + + for (size_t i = 0; i < orig_outputs.size(); i++) + { + graph->add_output()->set_name(orig_outputs[i]); + } + } +} + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/fold_constants.h b/tools/pnnx/src/pass_onnx/fold_constants.h new file mode 100644 index 00000000000..9728a7aaba6 --- /dev/null +++ b/tools/pnnx/src/pass_onnx/fold_constants.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "onnx.pb.h" + +namespace pnnx { + +namespace onnx2pnnx { + +void fold_constants(onnx::ModelProto& model); + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/inline_containers.cpp b/tools/pnnx/src/pass_onnx/inline_containers.cpp new file mode 100644 index 00000000000..a4ea80af614 --- /dev/null +++ b/tools/pnnx/src/pass_onnx/inline_containers.cpp @@ -0,0 +1,189 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. 
+// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "inline_containers.h" + +#include +#include + +namespace pnnx { + +namespace onnx2pnnx { + +static bool string_starts_with(const std::string& s, const std::string& s2) +{ + return strncmp(s.c_str(), s2.c_str(), s2.size()) == 0; +} + +void inline_containers(onnx::ModelProto& model) +{ + onnx::GraphProto* graph = model.mutable_graph(); + + for (int i = 0; i < graph->node_size(); i++) + { + onnx::NodeProto* node = graph->mutable_node(i); + + const std::string& op_type = node->op_type(); + + if (node->domain().empty()) + { + // native onnx op + + // Constant + // fprintf(stderr, " node = onnx %s\n", op_type.c_str()); + continue; + } + + if (string_starts_with(op_type, "torch_nn_modules_") && !string_starts_with(op_type, "torch_nn_modules_container_")) + { + // torch_nn_modules_conv_Conv2d _conv1_1 + // torch_nn_modules_batchnorm_BatchNorm2d _bn1_1 + // torch_nn_modules_pooling_MaxPool2d _maxpool_1_3 + // torch_nn_modules_linear_Linear _fc_1 + + // std::vector tokens = string_split(op_type, '_'); + + // fprintf(stderr, " node = nn.%s\n", tokens[4].c_str()); + continue; + } + + if (string_starts_with(op_type, "aten_") || string_starts_with(op_type, "_aten_")) + { + // aten_view + + // std::vector tokens = string_split(op_type, '_'); + + // fprintf(stderr, " node = aten::%s\n", tokens[1].c_str()); + continue; + } + + if (string_starts_with(op_type, "prims_")) + { + // prims_convert_element_type + continue; + } + + // find function + int function_index = -1; + for (int j = 0; j < model.functions_size(); j++) + { + const onnx::FunctionProto& function = model.functions(j); + if (function.name() == op_type) + { + function_index = j; + break; + } + } + + if (function_index == -1) + { + fprintf(stderr, "no such function with name %s\n", op_type.c_str()); + continue; + } + + // ok, this is a function, inline it at node + // fprintf(stderr, "inline %s\n", op_type.c_str()); + + const onnx::FunctionProto& function = model.functions(function_index); + + // build function input and output name remap + std::map input_output_remap; + { + for (int j = 0; j < node->input_size(); j++) + { + const std::string& node_input = node->input(j); + const std::string& func_input = function.input(j); + + input_output_remap[func_input] = node_input; + } + for (int j = 0; j < node->output_size(); j++) + { + const std::string& node_output = node->output(j); + const std::string& func_output = function.output(j); + + input_output_remap[func_output] = node_output; + } + } + + // append function nodes to graph + { + graph->mutable_node()->Reserve(graph->node_size() + function.node_size()); + for (int j = 0; j < function.node_size(); j++) + { + onnx::NodeProto* inlined_node = graph->add_node(); + inlined_node->CopyFrom(function.node(j)); + + // prefix with caller node name + inlined_node->set_name(node->name() + "/" + inlined_node->name()); + + // reset input output + for (int j = 0; j < inlined_node->input_size(); j++) + { + const std::string& node_input = 
inlined_node->input(j); + if (input_output_remap.find(node_input) != input_output_remap.end()) + { + inlined_node->set_input(j, input_output_remap.at(node_input)); + } + else + { + // graph->add_value_info()->set_name(node->name() + "/" + node_input); + inlined_node->set_input(j, node->name() + "/" + node_input); + } + } + for (int j = 0; j < inlined_node->output_size(); j++) + { + const std::string& node_output = inlined_node->output(j); + if (input_output_remap.find(node_output) != input_output_remap.end()) + { + inlined_node->set_output(j, input_output_remap.at(node_output)); + } + else + { + // graph->add_value_info()->set_name(node->name() + "/" + node_output); + inlined_node->set_output(j, node->name() + "/" + node_output); + } + } + } + } + + // swap inlined function nodes to caller + { + // ..... cni ....... 0 1 2 3 4 + const int graph_node_size = graph->node_size(); + for (int j = 0; j < function.node_size(); j++) + { + for (int k = graph_node_size - 1; k > i; k--) + { + graph->mutable_node()->SwapElements(k, k - 1); + } + } + + // ..... 0 1 2 3 4 cni ....... + for (int j = i + function.node_size(); j < graph_node_size - 1; j++) + { + graph->mutable_node()->SwapElements(j, j + 1); + } + + // ..... 0 1 2 3 4 ....... cni + graph->mutable_node()->RemoveLast(); + } + + // inlined node may be function + i -= 1; + } +} + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/inline_containers.h b/tools/pnnx/src/pass_onnx/inline_containers.h new file mode 100644 index 00000000000..56b21f47b37 --- /dev/null +++ b/tools/pnnx/src/pass_onnx/inline_containers.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "onnx.pb.h" + +namespace pnnx { + +namespace onnx2pnnx { + +void inline_containers(onnx::ModelProto& model); + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/model_stat.cpp b/tools/pnnx/src/pass_onnx/model_stat.cpp new file mode 100644 index 00000000000..6c61dfa2bd4 --- /dev/null +++ b/tools/pnnx/src/pass_onnx/model_stat.cpp @@ -0,0 +1,581 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
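+// Counts the operators in the graph by category (native onnx, aten, prims,
+// torch.nn module, custom module), including the bodies of functions that
+// are actually called, so that print_model_stat() can print an orig/opt
+// comparison table.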
+ +#include "model_stat.h" + +namespace pnnx { + +namespace onnx2pnnx { + +static bool string_starts_with(const std::string& s, const std::string& s2) +{ + return strncmp(s.c_str(), s2.c_str(), s2.size()) == 0; +} + +ModelStat get_model_stat(const onnx::ModelProto& model) +{ + ModelStat stat; + + const onnx::GraphProto& graph = model.graph(); + + stat.node_size = graph.node_size(); + for (int i = 0; i < model.functions_size(); i++) + { + stat.node_size += model.functions(i).node_size(); + } + + stat.initializer_size = graph.initializer_size(); + stat.functions_size = model.functions_size(); + + for (int i = 0; i < graph.node_size(); i++) + { + const onnx::NodeProto& node = graph.node(i); + + const std::string& op_type = node.op_type(); + + if (node.domain().empty()) + { + // native onnx op + stat.onnx_count += 1; + + if (stat.onnx_op_count.find(op_type) == stat.onnx_op_count.end()) + { + stat.onnx_op_count[op_type] = 1; + } + else + { + stat.onnx_op_count[op_type] = stat.onnx_op_count[op_type] + 1; + } + continue; + } + + if (string_starts_with(op_type, "aten_") || string_starts_with(op_type, "_aten_")) + { + // aten_view + stat.aten_count += 1; + + std::string simname = op_type; + if (simname[0] == '_') + simname = simname.substr(1); + simname[4] = '.'; + + if (stat.aten_op_count.find(simname) == stat.aten_op_count.end()) + { + stat.aten_op_count[simname] = 1; + } + else + { + stat.aten_op_count[simname] = stat.aten_op_count[simname] + 1; + } + continue; + } + + if (string_starts_with(op_type, "prims_")) + { + // prims_convert_element_type + stat.prims_count += 1; + + std::string simname = op_type; + simname[5] = '.'; + + if (stat.prims_op_count.find(simname) == stat.prims_op_count.end()) + { + stat.prims_op_count[simname] = 1; + } + else + { + stat.prims_op_count[simname] = stat.prims_op_count[simname] + 1; + } + continue; + } + + if (string_starts_with(op_type, "torch_nn_modules_") || string_starts_with(op_type, "nn_")) + { + // torch_nn_modules_conv_Conv2d _conv1_1 + stat.nn_module_count += 1; + + std::string simname; + if (string_starts_with(op_type, "nn_")) + { + // nn_Conv2d_i -> nn.Conv2d + simname = op_type; + simname[2] = '.'; + if (simname.find_first_of('_') != std::string::npos) + simname = simname.substr(0, simname.find_first_of('_')); + } + else + { + // torch_nn_modules_conv_Conv2d_xyz -> nn.Conv2d + char nn_type[256]; + sscanf(op_type.c_str() + sizeof("torch_nn_modules_") - 1, "%*[^_]_%255[^_]", nn_type); + simname = std::string("nn.") + nn_type; + } + + if (stat.nn_module_op_count.find(simname) == stat.nn_module_op_count.end()) + { + stat.nn_module_op_count[simname] = 1; + } + else + { + stat.nn_module_op_count[simname] = stat.nn_module_op_count[simname] + 1; + } + continue; + } + + // custom module op + stat.custom_module_count += 1; + } + + // collect called functions + std::unordered_set called_functions; + { + for (int i = 0; i < graph.node_size(); i++) + { + const onnx::NodeProto& node = graph.node(i); + + const std::string& op_type = node.op_type(); + + if (node.domain().empty()) + { + // native onnx op + continue; + } + + if (string_starts_with(op_type, "aten_") || string_starts_with(op_type, "_aten_")) + { + // aten_view + continue; + } + + if (string_starts_with(op_type, "prims_")) + { + // prims_convert_element_type + continue; + } + + if ((string_starts_with(op_type, "torch_nn_modules_") && !string_starts_with(op_type, "torch_nn_modules_container_")) || string_starts_with(op_type, "nn_")) + { + // torch_nn_modules_conv_Conv2d _conv1_1 + continue; + } + + 
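+            // record remaining module/container function calls so the ops
+            // inside their bodies (and inside nested callees, collected
+            // below) are counted too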
called_functions.insert(op_type); + } + + while (1) + { + bool new_called_function = false; + + for (int i = 0; i < model.functions_size(); i++) + { + const onnx::FunctionProto& function = model.functions(i); + + if (called_functions.find(function.name()) == called_functions.end()) + continue; + + for (int j = 0; j < function.node_size(); j++) + { + const onnx::NodeProto& node = function.node(j); + + const std::string& op_type = node.op_type(); + + if (node.domain().empty()) + { + // native onnx op + continue; + } + + if (string_starts_with(op_type, "aten_") || string_starts_with(op_type, "_aten_")) + { + // aten_view + continue; + } + + if (string_starts_with(op_type, "prims_")) + { + // prims_convert_element_type + continue; + } + + if ((string_starts_with(op_type, "torch_nn_modules_") && !string_starts_with(op_type, "torch_nn_modules_container_")) || string_starts_with(op_type, "nn_")) + { + // torch_nn_modules_conv_Conv2d _conv1_1 + continue; + } + + if (called_functions.find(op_type) == called_functions.end()) + { + called_functions.insert(op_type); + new_called_function = true; + } + } + } + + if (!new_called_function) + break; + } + } + + for (int i = 0; i < model.functions_size(); i++) + { + const onnx::FunctionProto& function = model.functions(i); + + if (called_functions.find(function.name()) == called_functions.end()) + continue; + + for (int j = 0; j < function.node_size(); j++) + { + const onnx::NodeProto& node = function.node(j); + + const std::string& op_type = node.op_type(); + + if (node.domain().empty()) + { + // native onnx op + stat.onnx_count += 1; + + if (stat.onnx_op_count.find(op_type) == stat.onnx_op_count.end()) + { + stat.onnx_op_count[op_type] = 1; + } + else + { + stat.onnx_op_count[op_type] = stat.onnx_op_count[op_type] + 1; + } + continue; + } + + if (string_starts_with(op_type, "aten_") || string_starts_with(op_type, "_aten_")) + { + // aten_view + stat.aten_count += 1; + + std::string simname = op_type; + if (simname[0] == '_') + simname = simname.substr(1); + simname[4] = '.'; + + if (stat.aten_op_count.find(simname) == stat.aten_op_count.end()) + { + stat.aten_op_count[simname] = 1; + } + else + { + stat.aten_op_count[simname] = stat.aten_op_count[simname] + 1; + } + continue; + } + + if (string_starts_with(op_type, "prims_")) + { + // prims_convert_element_type + stat.prims_count += 1; + + std::string simname = op_type; + simname[5] = '.'; + + if (stat.prims_op_count.find(simname) == stat.prims_op_count.end()) + { + stat.prims_op_count[simname] = 1; + } + else + { + stat.prims_op_count[simname] = stat.prims_op_count[simname] + 1; + } + continue; + } + + if (string_starts_with(op_type, "torch_nn_modules_") || string_starts_with(op_type, "nn_")) + { + // torch_nn_modules_conv_Conv2d _conv1_1 + stat.nn_module_count += 1; + + std::string simname; + if (string_starts_with(op_type, "nn_")) + { + simname = op_type; + simname[2] = '.'; + if (simname.find_first_of('_') != std::string::npos) + simname = simname.substr(0, simname.find_first_of('_')); + } + else + { + // torch_nn_modules_conv_Conv2d_xyz -> nn_Conv2d_i + char nn_type[256]; + sscanf(op_type.c_str() + sizeof("torch_nn_modules_") - 1, "%*[^_]_%255[^_]", nn_type); + simname = std::string("nn.") + nn_type; + } + + if (stat.nn_module_op_count.find(simname) == stat.nn_module_op_count.end()) + { + stat.nn_module_op_count[simname] = 1; + } + else + { + stat.nn_module_op_count[simname] = stat.nn_module_op_count[simname] + 1; + } + continue; + } + + // custom module op + stat.custom_module_count += 1; + } + } + 
+ return stat; +} + +void print_model_stat(const ModelStat& oldstat, const ModelStat& newstat) +{ + std::set nn_module_op_count; + std::set aten_op_count; + std::set prims_op_count; + std::set onnx_op_count; + { + for (auto& x : oldstat.nn_module_op_count) + { + nn_module_op_count.insert(x.first); + } + for (auto& x : newstat.nn_module_op_count) + { + nn_module_op_count.insert(x.first); + } + + for (auto& x : oldstat.aten_op_count) + { + aten_op_count.insert(x.first); + } + for (auto& x : newstat.aten_op_count) + { + aten_op_count.insert(x.first); + } + + for (auto& x : oldstat.prims_op_count) + { + prims_op_count.insert(x.first); + } + for (auto& x : newstat.prims_op_count) + { + prims_op_count.insert(x.first); + } + + for (auto& x : oldstat.onnx_op_count) + { + onnx_op_count.insert(x.first); + } + for (auto& x : newstat.onnx_op_count) + { + onnx_op_count.insert(x.first); + } + } + + // resolve longest text + int max_op_name_length = 16; + for (auto& x : nn_module_op_count) + { + max_op_name_length = std::max(max_op_name_length, (int)x.size()); + } + for (auto& x : aten_op_count) + { + max_op_name_length = std::max(max_op_name_length, (int)x.size()); + } + for (auto& x : prims_op_count) + { + max_op_name_length = std::max(max_op_name_length, (int)x.size()); + } + for (auto& x : onnx_op_count) + { + max_op_name_length = std::max(max_op_name_length, (int)x.size()); + } + + fprintf(stderr, "┌─"); + for (int i = 0; i < max_op_name_length; i++) + fprintf(stderr, "─"); + fprintf(stderr, "─┬──────────┬──────────┐\n"); + + fprintf(stderr, "│ %-*s │ orig │ opt │\n", max_op_name_length, ""); + + fprintf(stderr, "├─"); + for (int i = 0; i < max_op_name_length; i++) + fprintf(stderr, "─"); + fprintf(stderr, "─┼──────────┼──────────┤\n"); + + if (newstat.node_size < oldstat.node_size) + fprintf(stderr, "│ %-*s │ %-8d │ \033[32m%-8d\033[0m │\n", max_op_name_length, "node", oldstat.node_size, newstat.node_size); + else + fprintf(stderr, "│ %-*s │ %-8d │ %-8d │\n", max_op_name_length, "node", oldstat.node_size, newstat.node_size); + + fprintf(stderr, "│ %-*s │ %-8d │ %-8d │\n", max_op_name_length, "initializer", oldstat.initializer_size, newstat.initializer_size); + + if (newstat.functions_size < oldstat.functions_size) + fprintf(stderr, "│ %-*s │ %-8d │ \033[32m%-8d\033[0m │\n", max_op_name_length, "functions", oldstat.functions_size, newstat.functions_size); + else + fprintf(stderr, "│ %-*s │ %-8d │ %-8d │\n", max_op_name_length, "functions", oldstat.functions_size, newstat.functions_size); + + fprintf(stderr, "├─"); + for (int i = 0; i < max_op_name_length; i++) + fprintf(stderr, "─"); + fprintf(stderr, "─┼──────────┼──────────┤\n"); + + if (newstat.nn_module_count < oldstat.nn_module_count) + fprintf(stderr, "│ %-*s │ %-8d │ \033[32m%-8d\033[0m │\n", max_op_name_length, "nn module op", oldstat.nn_module_count, newstat.nn_module_count); + else + fprintf(stderr, "│ %-*s │ %-8d │ %-8d │\n", max_op_name_length, "nn module op", oldstat.nn_module_count, newstat.nn_module_count); + + if (newstat.custom_module_count < oldstat.custom_module_count) + fprintf(stderr, "│ %-*s │ %-8d │ \033[32m%-8d\033[0m │\n", max_op_name_length, "custom module op", oldstat.custom_module_count, newstat.custom_module_count); + else + fprintf(stderr, "│ %-*s │ %-8d │ %-8d │\n", max_op_name_length, "custom module op", oldstat.custom_module_count, newstat.custom_module_count); + + if (newstat.aten_count < oldstat.aten_count) + fprintf(stderr, "│ %-*s │ %-8d │ \033[32m%-8d\033[0m │\n", max_op_name_length, "aten op", 
oldstat.aten_count, newstat.aten_count); + else + fprintf(stderr, "│ %-*s │ %-8d │ %-8d │\n", max_op_name_length, "aten op", oldstat.aten_count, newstat.aten_count); + + if (newstat.prims_count < oldstat.prims_count) + fprintf(stderr, "│ %-*s │ %-8d │ \033[32m%-8d\033[0m │\n", max_op_name_length, "prims op", oldstat.prims_count, newstat.prims_count); + else + fprintf(stderr, "│ %-*s │ %-8d │ %-8d │\n", max_op_name_length, "prims op", oldstat.prims_count, newstat.prims_count); + + if (newstat.onnx_count < oldstat.onnx_count) + fprintf(stderr, "│ %-*s │ %-8d │ \033[32m%-8d\033[0m │\n", max_op_name_length, "onnx native op", oldstat.onnx_count, newstat.onnx_count); + else + fprintf(stderr, "│ %-*s │ %-8d │ %-8d │\n", max_op_name_length, "onnx native op", oldstat.onnx_count, newstat.onnx_count); + + fprintf(stderr, "├─"); + for (int i = 0; i < max_op_name_length; i++) + fprintf(stderr, "─"); + fprintf(stderr, "─┼──────────┼──────────┤\n"); + + // merge nn_module_op_count + { + for (auto x : nn_module_op_count) + { + int oldcount = 0; + int newcount = 0; + if (oldstat.nn_module_op_count.find(x) != oldstat.nn_module_op_count.end()) + { + oldcount = oldstat.nn_module_op_count.at(x); + } + if (newstat.nn_module_op_count.find(x) != newstat.nn_module_op_count.end()) + { + newcount = newstat.nn_module_op_count.at(x); + } + + if (newcount < oldcount) + fprintf(stderr, "│ %-*s │ %-8d │ \033[32m%-8d\033[0m │\n", max_op_name_length, x.c_str(), oldcount, newcount); + else + fprintf(stderr, "│ %-*s │ %-8d │ %-8d │\n", max_op_name_length, x.c_str(), oldcount, newcount); + } + + if (!nn_module_op_count.empty()) + { + fprintf(stderr, "├─"); + for (int i = 0; i < max_op_name_length; i++) + fprintf(stderr, "─"); + fprintf(stderr, "─┼──────────┼──────────┤\n"); + } + } + + // merge aten_op_count + { + for (auto x : aten_op_count) + { + int oldcount = 0; + int newcount = 0; + if (oldstat.aten_op_count.find(x) != oldstat.aten_op_count.end()) + { + oldcount = oldstat.aten_op_count.at(x); + } + if (newstat.aten_op_count.find(x) != newstat.aten_op_count.end()) + { + newcount = newstat.aten_op_count.at(x); + } + + if (newcount < oldcount) + fprintf(stderr, "│ %-*s │ %-8d │ \033[32m%-8d\033[0m │\n", max_op_name_length, x.c_str(), oldcount, newcount); + else + fprintf(stderr, "│ %-*s │ %-8d │ %-8d │\n", max_op_name_length, x.c_str(), oldcount, newcount); + } + + if (!aten_op_count.empty()) + { + fprintf(stderr, "├─"); + for (int i = 0; i < max_op_name_length; i++) + fprintf(stderr, "─"); + fprintf(stderr, "─┼──────────┼──────────┤\n"); + } + } + + // merge prims_op_count + { + for (auto x : prims_op_count) + { + int oldcount = 0; + int newcount = 0; + if (oldstat.prims_op_count.find(x) != oldstat.prims_op_count.end()) + { + oldcount = oldstat.prims_op_count.at(x); + } + if (newstat.prims_op_count.find(x) != newstat.prims_op_count.end()) + { + newcount = newstat.prims_op_count.at(x); + } + + if (newcount < oldcount) + fprintf(stderr, "│ %-*s │ %-8d │ \033[32m%-8d\033[0m │\n", max_op_name_length, x.c_str(), oldcount, newcount); + else + fprintf(stderr, "│ %-*s │ %-8d │ %-8d │\n", max_op_name_length, x.c_str(), oldcount, newcount); + } + + if (!prims_op_count.empty()) + { + fprintf(stderr, "├─"); + for (int i = 0; i < max_op_name_length; i++) + fprintf(stderr, "─"); + fprintf(stderr, "─┼──────────┼──────────┤\n"); + } + } + + // merge onnx_op_count + { + for (auto x : onnx_op_count) + { + int oldcount = 0; + int newcount = 0; + if (oldstat.onnx_op_count.find(x) != oldstat.onnx_op_count.end()) + { + oldcount = 
oldstat.onnx_op_count.at(x);
+            }
+            if (newstat.onnx_op_count.find(x) != newstat.onnx_op_count.end())
+            {
+                newcount = newstat.onnx_op_count.at(x);
+            }
+
+            if (newcount < oldcount)
+                fprintf(stderr, "│ %-*s │ %-8d │ \033[32m%-8d\033[0m │\n", max_op_name_length, x.c_str(), oldcount, newcount);
+            else
+                fprintf(stderr, "│ %-*s │ %-8d │ %-8d │\n", max_op_name_length, x.c_str(), oldcount, newcount);
+        }
+    }
+
+    fprintf(stderr, "└─");
+    for (int i = 0; i < max_op_name_length; i++)
+        fprintf(stderr, "─");
+    fprintf(stderr, "─┴──────────┴──────────┘\n");
+}
+
+} // namespace onnx2pnnx
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_onnx/model_stat.h b/tools/pnnx/src/pass_onnx/model_stat.h
new file mode 100644
index 00000000000..dd62e67a1bc
--- /dev/null
+++ b/tools/pnnx/src/pass_onnx/model_stat.h
@@ -0,0 +1,58 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "onnx.pb.h"
+
+namespace pnnx {
+
+namespace onnx2pnnx {
+
+struct ModelStat
+{
+    ModelStat()
+    {
+        node_size = 0;
+        initializer_size = 0;
+        functions_size = 0;
+
+        nn_module_count = 0;
+        custom_module_count = 0;
+        aten_count = 0;
+        prims_count = 0;
+        onnx_count = 0;
+    }
+
+    int node_size;
+    int initializer_size;
+    int functions_size;
+
+    int nn_module_count;
+    int custom_module_count;
+    int aten_count;
+    int prims_count;
+    int onnx_count;
+
+    std::map<std::string, int> nn_module_op_count;
+    std::map<std::string, int> aten_op_count;
+    std::map<std::string, int> prims_op_count;
+    std::map<std::string, int> onnx_op_count;
+};
+
+ModelStat get_model_stat(const onnx::ModelProto& model);
+
+void print_model_stat(const ModelStat& oldstat, const ModelStat& newstat);
+
+} // namespace onnx2pnnx
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_onnx/nn_AdaptiveAvgPool2d.cpp b/tools/pnnx/src/pass_onnx/nn_AdaptiveAvgPool2d.cpp
new file mode 100644
index 00000000000..0e8851f05f2
--- /dev/null
+++ b/tools/pnnx/src/pass_onnx/nn_AdaptiveAvgPool2d.cpp
@@ -0,0 +1,52 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "pass_onnx.h"
+#include "ir.h"
+
+#include "onnx.pb.h"
+
+namespace pnnx {
+
+namespace onnx2pnnx {
+
+class AdaptiveAvgPool2d : public FuseFunctionPass
+{
+public:
+    const char* match_type_str() const
+    {
+        return "nn.AdaptiveAvgPool2d";
+    }
+
+    const char* type_str() const
+    {
+        return "nn.AdaptiveAvgPool2d";
+    }
+
+    void write(Operator* op, const OnnxFunctionProxy& function) const
+    {
+        const std::vector<int>& out_shape = op->outputs[0]->shape;
+
+        if (out_shape.size() == 3)
+            op->params["output_size"] = std::vector<int> {out_shape[1], out_shape[2]};
+        else // if (out_shape.size() == 4)
+            op->params["output_size"] = std::vector<int> {out_shape[2], out_shape[3]};
+    }
+};
+
+REGISTER_GLOBAL_PNNX_FUSE_FUNCTION_PASS(AdaptiveAvgPool2d)
+
+} // namespace onnx2pnnx
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_onnx/nn_AdaptiveAvgPool3d.cpp b/tools/pnnx/src/pass_onnx/nn_AdaptiveAvgPool3d.cpp
new file mode 100644
index 00000000000..070981e1d64
--- /dev/null
+++ b/tools/pnnx/src/pass_onnx/nn_AdaptiveAvgPool3d.cpp
@@ -0,0 +1,52 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "pass_onnx.h"
+#include "ir.h"
+
+#include "onnx.pb.h"
+
+namespace pnnx {
+
+namespace onnx2pnnx {
+
+class AdaptiveAvgPool3d : public FuseFunctionPass
+{
+public:
+    const char* match_type_str() const
+    {
+        return "nn.AdaptiveAvgPool3d";
+    }
+
+    const char* type_str() const
+    {
+        return "nn.AdaptiveAvgPool3d";
+    }
+
+    void write(Operator* op, const OnnxFunctionProxy& function) const
+    {
+        const std::vector<int>& out_shape = op->outputs[0]->shape;
+
+        if (out_shape.size() == 4)
+            op->params["output_size"] = std::vector<int> {out_shape[1], out_shape[2], out_shape[3]};
+        else // if (out_shape.size() == 5)
+            op->params["output_size"] = std::vector<int> {out_shape[2], out_shape[3], out_shape[4]};
+    }
+};
+
+REGISTER_GLOBAL_PNNX_FUSE_FUNCTION_PASS(AdaptiveAvgPool3d)
+
+} // namespace onnx2pnnx
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_onnx/nn_AvgPool2d.cpp b/tools/pnnx/src/pass_onnx/nn_AvgPool2d.cpp
new file mode 100644
index 00000000000..5a006fe3709
--- /dev/null
+++ b/tools/pnnx/src/pass_onnx/nn_AvgPool2d.cpp
@@ -0,0 +1,65 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+ +#include "pass_onnx.h" +#include "ir.h" + +#include "onnx.pb.h" + +namespace pnnx { + +namespace onnx2pnnx { + +class AvgPool2d : public FuseFunctionPass +{ +public: + const char* match_type_str() const + { + return "nn.AvgPool2d"; + } + + const char* type_str() const + { + return "nn.AvgPool2d"; + } + + void write(Operator* op, const OnnxFunctionProxy& function) const + { + const OnnxNodeProxy averagepool = function.typed_node("AveragePool"); + + std::vector kernel_shape = averagepool.attribute("kernel_shape"); + std::vector strides = averagepool.attribute("strides"); + std::vector pads = averagepool.attribute("pads"); + int64_t ceil_mode = averagepool.attribute("ceil_mode"); + int64_t count_include_pad = averagepool.attribute("count_include_pad"); + + if (pads.size() == 4) + { + pads = {pads[0], pads[1]}; + } + + op->params["kernel_size"] = kernel_shape; + op->params["stride"] = strides; + op->params["padding"] = pads; + op->params["ceil_mode"] = (ceil_mode != 0); + op->params["count_include_pad"] = (count_include_pad != 0); + op->params["divisor_override"] = Parameter(); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_FUNCTION_PASS(AvgPool2d) + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/nn_AvgPool3d.cpp b/tools/pnnx/src/pass_onnx/nn_AvgPool3d.cpp new file mode 100644 index 00000000000..ff2a5dd8aad --- /dev/null +++ b/tools/pnnx/src/pass_onnx/nn_AvgPool3d.cpp @@ -0,0 +1,65 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "pass_onnx.h" +#include "ir.h" + +#include "onnx.pb.h" + +namespace pnnx { + +namespace onnx2pnnx { + +class AvgPool3d : public FuseFunctionPass +{ +public: + const char* match_type_str() const + { + return "nn.AvgPool3d"; + } + + const char* type_str() const + { + return "nn.AvgPool3d"; + } + + void write(Operator* op, const OnnxFunctionProxy& function) const + { + const OnnxNodeProxy averagepool = function.typed_node("AveragePool"); + + std::vector kernel_shape = averagepool.attribute("kernel_shape"); + std::vector strides = averagepool.attribute("strides"); + std::vector pads = averagepool.attribute("pads"); + int64_t ceil_mode = averagepool.attribute("ceil_mode"); + int64_t count_include_pad = averagepool.attribute("count_include_pad"); + + if (pads.size() == 6) + { + pads = {pads[0], pads[1], pads[2]}; + } + + op->params["kernel_size"] = kernel_shape; + op->params["stride"] = strides; + op->params["padding"] = pads; + op->params["ceil_mode"] = (ceil_mode != 0); + op->params["count_include_pad"] = (count_include_pad != 0); + op->params["divisor_override"] = Parameter(); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_FUNCTION_PASS(AvgPool3d) + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/nn_BatchNorm2d.cpp b/tools/pnnx/src/pass_onnx/nn_BatchNorm2d.cpp new file mode 100644 index 00000000000..c3639904d47 --- /dev/null +++ b/tools/pnnx/src/pass_onnx/nn_BatchNorm2d.cpp @@ -0,0 +1,72 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "pass_onnx.h" +#include "ir.h" + +#include "onnx.pb.h" + +namespace pnnx { + +namespace onnx2pnnx { + +class BatchNorm2d : public FuseFunctionPass +{ +public: + const char* match_type_str() const + { + return "nn.BatchNorm2d"; + } + + const char* type_str() const + { + return "nn.BatchNorm2d"; + } + + void write(Operator* op, const OnnxFunctionProxy& function) const + { + float eps; + if (function.has_typed_node("_aten_native_batch_norm_inference_onnx")) + { + const OnnxNodeProxy aten_native_batch_norm_inference_onnx = function.typed_node("_aten_native_batch_norm_inference_onnx"); + eps = aten_native_batch_norm_inference_onnx.attribute("eps"); + } + else + { + const OnnxNodeProxy add_eps = function.named_node("aten_add_5"); + eps = function.find_producer(add_eps.node.input(1)).attribute("value"); + } + + const onnx::TensorProto& running_mean = function.initializer("running_mean"); + const onnx::TensorProto& running_var = function.initializer("running_var"); + + op->params["num_features"] = running_mean.dims(0); + op->params["eps"] = eps; + op->params["affine"] = function.has_initializer("weight") && function.has_initializer("bias"); + + op->attrs["running_mean"] = running_mean; + op->attrs["running_var"] = running_var; + if (function.has_initializer("weight") && function.has_initializer("bias")) + { + op->attrs["weight"] = function.initializer("weight"); + op->attrs["bias"] = function.initializer("bias"); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_FUNCTION_PASS(BatchNorm2d) + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/nn_BatchNorm3d.cpp b/tools/pnnx/src/pass_onnx/nn_BatchNorm3d.cpp new file mode 100644 index 00000000000..0f9405f160a --- /dev/null +++ b/tools/pnnx/src/pass_onnx/nn_BatchNorm3d.cpp @@ -0,0 +1,72 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "pass_onnx.h" +#include "ir.h" + +#include "onnx.pb.h" + +namespace pnnx { + +namespace onnx2pnnx { + +class BatchNorm3d : public FuseFunctionPass +{ +public: + const char* match_type_str() const + { + return "nn.BatchNorm3d"; + } + + const char* type_str() const + { + return "nn.BatchNorm3d"; + } + + void write(Operator* op, const OnnxFunctionProxy& function) const + { + float eps; + if (function.has_typed_node("_aten_native_batch_norm_inference_onnx")) + { + const OnnxNodeProxy aten_native_batch_norm_inference_onnx = function.typed_node("_aten_native_batch_norm_inference_onnx"); + eps = aten_native_batch_norm_inference_onnx.attribute("eps"); + } + else + { + const OnnxNodeProxy add_eps = function.named_node("aten_add_5"); + eps = function.find_producer(add_eps.node.input(1)).attribute("value"); + } + + const onnx::TensorProto& running_mean = function.initializer("running_mean"); + const onnx::TensorProto& running_var = function.initializer("running_var"); + + op->params["num_features"] = running_mean.dims(0); + op->params["eps"] = eps; + op->params["affine"] = function.has_initializer("weight") && function.has_initializer("bias"); + + op->attrs["running_mean"] = running_mean; + op->attrs["running_var"] = running_var; + if (function.has_initializer("weight") && function.has_initializer("bias")) + { + op->attrs["weight"] = function.initializer("weight"); + op->attrs["bias"] = function.initializer("bias"); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_FUNCTION_PASS(BatchNorm3d) + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/nn_Conv2d.cpp b/tools/pnnx/src/pass_onnx/nn_Conv2d.cpp new file mode 100644 index 00000000000..c9aeac561ac --- /dev/null +++ b/tools/pnnx/src/pass_onnx/nn_Conv2d.cpp @@ -0,0 +1,75 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "pass_onnx.h" +#include "ir.h" + +#include "onnx.pb.h" + +namespace pnnx { + +namespace onnx2pnnx { + +class Conv2d : public FuseFunctionPass +{ +public: + const char* match_type_str() const + { + return "nn.Conv2d"; + } + + const char* type_str() const + { + return "nn.Conv2d"; + } + + void write(Operator* op, const OnnxFunctionProxy& function) const + { + const OnnxNodeProxy aten_convolution_onnx = function.typed_node("_aten_convolution_onnx"); + + std::vector dilations = aten_convolution_onnx.attribute("dilations"); + std::vector strides = aten_convolution_onnx.attribute("strides"); + std::vector pads = aten_convolution_onnx.attribute("pads"); + int64_t groups = aten_convolution_onnx.attribute("groups"); + + const onnx::TensorProto& weight = function.initializer("weight"); + + if (pads.size() == 4) + { + pads = {pads[0], pads[1]}; + } + + op->params["in_channels"] = weight.dims(1) * groups; + op->params["out_channels"] = weight.dims(0); + op->params["kernel_size"] = {weight.dims(2), weight.dims(3)}; + op->params["dilation"] = dilations; + op->params["stride"] = strides; + op->params["padding"] = pads; + op->params["groups"] = groups; + op->params["bias"] = function.has_initializer("bias"); + op->params["padding_mode"] = "zeros"; + + op->attrs["weight"] = weight; + if (function.has_initializer("bias")) + { + op->attrs["bias"] = function.initializer("bias"); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_FUNCTION_PASS(Conv2d) + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/nn_Conv3d.cpp b/tools/pnnx/src/pass_onnx/nn_Conv3d.cpp new file mode 100644 index 00000000000..6413685fcb5 --- /dev/null +++ b/tools/pnnx/src/pass_onnx/nn_Conv3d.cpp @@ -0,0 +1,75 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "pass_onnx.h" +#include "ir.h" + +#include "onnx.pb.h" + +namespace pnnx { + +namespace onnx2pnnx { + +class Conv3d : public FuseFunctionPass +{ +public: + const char* match_type_str() const + { + return "nn.Conv3d"; + } + + const char* type_str() const + { + return "nn.Conv3d"; + } + + void write(Operator* op, const OnnxFunctionProxy& function) const + { + const OnnxNodeProxy aten_convolution_onnx = function.typed_node("_aten_convolution_onnx"); + + std::vector dilations = aten_convolution_onnx.attribute("dilations"); + std::vector strides = aten_convolution_onnx.attribute("strides"); + std::vector pads = aten_convolution_onnx.attribute("pads"); + int64_t groups = aten_convolution_onnx.attribute("groups"); + + const onnx::TensorProto& weight = function.initializer("weight"); + + if (pads.size() == 6) + { + pads = {pads[0], pads[1], pads[2]}; + } + + op->params["in_channels"] = weight.dims(1) * groups; + op->params["out_channels"] = weight.dims(0); + op->params["kernel_size"] = {weight.dims(2), weight.dims(3), weight.dims(4)}; + op->params["dilation"] = dilations; + op->params["stride"] = strides; + op->params["padding"] = pads; + op->params["groups"] = groups; + op->params["bias"] = function.has_initializer("bias"); + op->params["padding_mode"] = "zeros"; + + op->attrs["weight"] = weight; + if (function.has_initializer("bias")) + { + op->attrs["bias"] = function.initializer("bias"); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_FUNCTION_PASS(Conv3d) + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/nn_GELU.cpp b/tools/pnnx/src/pass_onnx/nn_GELU.cpp new file mode 100644 index 00000000000..f5b7000e017 --- /dev/null +++ b/tools/pnnx/src/pass_onnx/nn_GELU.cpp @@ -0,0 +1,54 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_onnx.h" +#include "ir.h" + +#include "onnx.pb.h" + +namespace pnnx { + +namespace onnx2pnnx { + +class GELU : public FuseFunctionPass +{ +public: + const char* match_type_str() const + { + return "nn.GELU"; + } + + const char* type_str() const + { + return "nn.GELU"; + } + + void write(Operator* op, const OnnxFunctionProxy& function) const + { + bool approximate_none = function.has_typed_node("_aten_gelu_approximate_none"); + bool approximate_tanh = function.has_typed_node("_aten_gelu_approximate_tanh"); + + if (approximate_none) + op->params["approximate"] = "none"; + + if (approximate_tanh) + op->params["approximate"] = "tanh"; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_FUNCTION_PASS(GELU) + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/nn_LayerNorm.cpp b/tools/pnnx/src/pass_onnx/nn_LayerNorm.cpp new file mode 100644 index 00000000000..f4ecf289557 --- /dev/null +++ b/tools/pnnx/src/pass_onnx/nn_LayerNorm.cpp @@ -0,0 +1,72 @@ +// Tencent is pleased to support the open source community by making ncnn available. 
+// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_onnx.h" +#include "ir.h" + +#include "onnx.pb.h" + +namespace pnnx { + +namespace onnx2pnnx { + +class LayerNorm : public FuseFunctionPass +{ +public: + const char* match_type_str() const + { + return "nn.LayerNorm"; + } + + const char* type_str() const + { + return "nn.LayerNorm"; + } + + void write(Operator* op, const OnnxFunctionProxy& function) const + { + const int input_rank = op->inputs[0]->shape.size(); + + const OnnxNodeProxy layernormalization = function.typed_node("LayerNormalization"); + + int64_t axis = layernormalization.attribute("axis"); + + if (axis < 0) + { + axis = input_rank + axis; + } + + std::vector normalized_shape; + for (int i = axis; i < input_rank; i++) + { + normalized_shape.push_back(op->inputs[0]->shape[i]); + } + + op->params["normalized_shape"] = normalized_shape; + op->params["eps"] = layernormalization.attribute("epsilon"); + op->params["elementwise_affine"] = function.has_initializer("weight") && function.has_initializer("bias"); + + if (function.has_initializer("weight") && function.has_initializer("bias")) + { + op->attrs["weight"] = function.initializer("weight"); + op->attrs["bias"] = function.initializer("bias"); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_FUNCTION_PASS(LayerNorm) + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/nn_Linear.cpp b/tools/pnnx/src/pass_onnx/nn_Linear.cpp new file mode 100644 index 00000000000..4dce81908b2 --- /dev/null +++ b/tools/pnnx/src/pass_onnx/nn_Linear.cpp @@ -0,0 +1,57 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "pass_onnx.h" +#include "ir.h" + +#include "onnx.pb.h" + +namespace pnnx { + +namespace onnx2pnnx { + +class Linear : public FuseFunctionPass +{ +public: + const char* match_type_str() const + { + return "nn.Linear"; + } + + const char* type_str() const + { + return "nn.Linear"; + } + + void write(Operator* op, const OnnxFunctionProxy& function) const + { + const onnx::TensorProto& weight = function.initializer("weight"); + + op->params["in_features"] = weight.dims(1); + op->params["out_features"] = weight.dims(0); + op->params["bias"] = function.has_initializer("bias"); + + op->attrs["weight"] = weight; + if (function.has_initializer("bias")) + { + op->attrs["bias"] = function.initializer("bias"); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_FUNCTION_PASS(Linear) + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/nn_MaxPool2d.cpp b/tools/pnnx/src/pass_onnx/nn_MaxPool2d.cpp new file mode 100644 index 00000000000..47924bd33fc --- /dev/null +++ b/tools/pnnx/src/pass_onnx/nn_MaxPool2d.cpp @@ -0,0 +1,65 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_onnx.h" +#include "ir.h" + +#include "onnx.pb.h" + +namespace pnnx { + +namespace onnx2pnnx { + +class MaxPool2d : public FuseFunctionPass +{ +public: + const char* match_type_str() const + { + return "nn.MaxPool2d"; + } + + const char* type_str() const + { + return "nn.MaxPool2d"; + } + + void write(Operator* op, const OnnxFunctionProxy& function) const + { + const OnnxNodeProxy aten_max_pool_with_indices_onnx = function.typed_node("_aten_max_pool_with_indices_onnx"); + + std::vector kernel_size = aten_max_pool_with_indices_onnx.attribute("kernel_size"); + std::vector dilation = aten_max_pool_with_indices_onnx.attribute("dilation"); + std::vector stride = aten_max_pool_with_indices_onnx.attribute("stride"); + std::vector padding = aten_max_pool_with_indices_onnx.attribute("padding"); + int64_t ceil_mode = aten_max_pool_with_indices_onnx.attribute("ceil_mode"); + + if (padding.size() == 4) + { + padding = {padding[0], padding[1]}; + } + + op->params["kernel_size"] = kernel_size; + op->params["dilation"] = dilation; + op->params["stride"] = stride; + op->params["padding"] = padding; + op->params["ceil_mode"] = (ceil_mode != 0); + op->params["return_indices"] = (function.function.output_size() != 1); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_FUNCTION_PASS(MaxPool2d) + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/nn_MaxPool3d.cpp b/tools/pnnx/src/pass_onnx/nn_MaxPool3d.cpp new file mode 100644 index 00000000000..c8c467f5ba2 --- /dev/null +++ b/tools/pnnx/src/pass_onnx/nn_MaxPool3d.cpp @@ -0,0 +1,65 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. 
+// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_onnx.h" +#include "ir.h" + +#include "onnx.pb.h" + +namespace pnnx { + +namespace onnx2pnnx { + +class MaxPool3d : public FuseFunctionPass +{ +public: + const char* match_type_str() const + { + return "nn.MaxPool3d"; + } + + const char* type_str() const + { + return "nn.MaxPool3d"; + } + + void write(Operator* op, const OnnxFunctionProxy& function) const + { + const OnnxNodeProxy aten_max_pool_with_indices_onnx = function.typed_node("_aten_max_pool_with_indices_onnx"); + + std::vector kernel_size = aten_max_pool_with_indices_onnx.attribute("kernel_size"); + std::vector dilation = aten_max_pool_with_indices_onnx.attribute("dilation"); + std::vector stride = aten_max_pool_with_indices_onnx.attribute("stride"); + std::vector padding = aten_max_pool_with_indices_onnx.attribute("padding"); + int64_t ceil_mode = aten_max_pool_with_indices_onnx.attribute("ceil_mode"); + + if (padding.size() == 6) + { + padding = {padding[0], padding[1], padding[2]}; + } + + op->params["kernel_size"] = kernel_size; + op->params["dilation"] = dilation; + op->params["stride"] = stride; + op->params["padding"] = padding; + op->params["ceil_mode"] = (ceil_mode != 0); + op->params["return_indices"] = (function.function.output_size() != 1); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_FUNCTION_PASS(MaxPool3d) + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/nn_MultiheadAttention.cpp b/tools/pnnx/src/pass_onnx/nn_MultiheadAttention.cpp new file mode 100644 index 00000000000..a29ec9d9306 --- /dev/null +++ b/tools/pnnx/src/pass_onnx/nn_MultiheadAttention.cpp @@ -0,0 +1,122 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "pass_onnx.h" +#include "ir.h" + +#include "onnx.pb.h" + +namespace pnnx { + +namespace onnx2pnnx { + +class MultiheadAttention : public FuseFunctionPass +{ +public: + const char* match_type_str() const + { + return "nn.MultiheadAttention"; + } + + const char* type_str() const + { + return "nn.MultiheadAttention"; + } + + void write(Operator* op, const OnnxFunctionProxy& function) const + { + const OnnxNodeProxy attention_scale = function.typed_node("_attention_scale"); + + const OnnxNodeProxy reshape_heads = function.find_producer(attention_scale.node.input(0)); + + const OnnxNodeProxy constant_shape = function.find_producer(reshape_heads.node.input(1)); + + if (constant_shape.node.op_type() == "Constant") + { + std::vector shape = constant_shape.attribute("value"); + op->params["num_heads"] = shape[1]; + } + + const OnnxNodeProxy transpose = function.typed_node("Transpose"); + std::vector perm = transpose.attribute("perm"); + if (perm == std::vector {1, 0, 2}) + { + op->params["batch_first"] = true; + } + else + { + op->params["batch_first"] = false; + } + + op->params["add_zero_attn"] = false; // TODO + + if (function.has_typed_node("_aten_scaled_dot_product_attention_no_mask_onnx")) + { + // TODO handle attn_mask + } + + if (function.has_initializer("in_proj_weight")) + { + const onnx::TensorProto& in_proj_weight = function.initializer("in_proj_weight"); + + op->params["embed_dim"] = in_proj_weight.dims(1); + op->params["kdim"] = in_proj_weight.dims(1); + op->params["vdim"] = in_proj_weight.dims(1); + op->attrs["in_proj_weight"] = in_proj_weight; + } + else + { + const onnx::TensorProto& q_proj_weight = function.initializer("q_proj_weight"); + const onnx::TensorProto& k_proj_weight = function.initializer("k_proj_weight"); + const onnx::TensorProto& v_proj_weight = function.initializer("v_proj_weight"); + + op->params["embed_dim"] = q_proj_weight.dims(1); + op->params["kdim"] = k_proj_weight.dims(1); + op->params["vdim"] = v_proj_weight.dims(1); + op->attrs["q_proj_weight"] = q_proj_weight; + op->attrs["k_proj_weight"] = k_proj_weight; + op->attrs["v_proj_weight"] = v_proj_weight; + } + + op->attrs["out_proj.weight"] = function.initializer("weight"); + + if (function.has_initializer("in_proj_bias") && function.has_initializer("bias")) + { + op->params["bias"] = true; + op->attrs["in_proj_bias"] = function.initializer("in_proj_bias"); + op->attrs["out_proj.bias"] = function.initializer("bias"); + } + else + { + op->params["bias"] = false; + } + + if (function.has_initializer("bias_k") && function.has_initializer("bias_v")) + { + op->params["add_bias_kv"] = true; + op->attrs["bias_k"] = function.initializer("bias_k"); + op->attrs["bias_v"] = function.initializer("bias_v"); + } + else + { + op->params["add_bias_kv"] = false; + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_FUNCTION_PASS(MultiheadAttention) + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/shape_inference.cpp b/tools/pnnx/src/pass_onnx/shape_inference.cpp new file mode 100644 index 00000000000..fb72e6b8513 --- /dev/null +++ b/tools/pnnx/src/pass_onnx/shape_inference.cpp @@ -0,0 +1,340 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "shape_inference.h"
+
+#include <string.h>
+#include <sstream>
+#include <vector>
+
+#include <onnxruntime_c_api.h>
+
+namespace pnnx {
+
+namespace onnx2pnnx {
+
+static bool string_starts_with(const std::string& s, const std::string& s2)
+{
+    return strncmp(s.c_str(), s2.c_str(), s2.size()) == 0;
+}
+
+void shape_inference(onnx::ModelProto& model)
+{
+    onnx::GraphProto* graph = model.mutable_graph();
+
+    // save original outputs
+    std::vector<std::string> orig_outputs;
+    {
+        for (int i = 0; i < graph->output_size(); i++)
+        {
+            orig_outputs.push_back(graph->output(i).name());
+        }
+    }
+
+    // collect intermediates
+    std::vector<std::string> intermediates;
+    {
+        for (int i = 0; i < graph->node_size(); i++)
+        {
+            const onnx::NodeProto& node = graph->node(i);
+
+            const std::string& op_type = node.op_type();
+
+            // blacklist some glues
+            if (op_type == "Constant")
+                continue;
+
+            // TODO fuse cat
+            if (op_type == "SequenceConstruct")
+                continue;
+
+            // TODO fuse chunk/tensor_split
+            if (op_type == "aten_split")
+                continue;
+
+            if (node.domain().empty() || string_starts_with(op_type, "nn_") || string_starts_with(op_type, "aten_") || string_starts_with(op_type, "_aten_"))
+            {
+                for (int j = 0; j < node.output_size(); j++)
+                {
+                    // some input/output may have empty name, it causes trouble, skip it
+                    if (node.output(j).empty())
+                        continue;
+
+                    intermediates.push_back(node.output(j));
+                }
+            }
+        }
+    }
+
+    // add intermediates to onnx output
+    {
+        graph->clear_output();
+
+        for (size_t i = 0; i < intermediates.size(); i++)
+        {
+            graph->add_output()->set_name(intermediates[i]);
+        }
+    }
+
+    // generate temp onnx graph
+    std::string tmp_onnx_data;
+    {
+        std::stringstream tmp_onnx_data_ss;
+        if (!model.SerializeToOstream(&tmp_onnx_data_ss))
+        {
+            fprintf(stderr, "write onnx failed\n");
+            return;
+        }
+
+        tmp_onnx_data = tmp_onnx_data_ss.str();
+    }
+
+    // onnxrt inference
+    {
+        const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+
+        OrtStatus* ort_status = 0;
+
+        OrtEnv* ort_env = 0;
+        ort_status = ort_api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "pnnx", &ort_env);
+        if (ort_status)
+        {
+            fprintf(stderr, "ort CreateEnv failed %s\n", ort_api->GetErrorMessage(ort_status));
+        }
+
+        OrtSessionOptions* ort_session_opt = 0;
+        ort_status = ort_api->CreateSessionOptions(&ort_session_opt);
+        if (ort_status)
+        {
+            fprintf(stderr, "ort CreateSessionOptions failed %s\n", ort_api->GetErrorMessage(ort_status));
+        }
+
+        ort_status = ort_api->SetSessionGraphOptimizationLevel(ort_session_opt, ORT_DISABLE_ALL);
+        if (ort_status)
+        {
+            fprintf(stderr, "ort SetSessionGraphOptimizationLevel failed %s\n", ort_api->GetErrorMessage(ort_status));
+        }
+
+        // ort_status = ort_api->SetIntraOpNumThreads(ort_session_opt, 4);
+        // if (ort_status)
+        // {
+        //     fprintf(stderr, "ort SetIntraOpNumThreads failed %s\n", ort_api->GetErrorMessage(ort_status));
+        // }
+        //
+        // ort_status = ort_api->SetInterOpNumThreads(ort_session_opt, 4);
+        // if (ort_status)
+        // {
+        //     fprintf(stderr, "ort SetInterOpNumThreads failed %s\n", ort_api->GetErrorMessage(ort_status));
+        // }
+
+        OrtSession* ort_session = 0;
+        ort_status = ort_api->CreateSessionFromArray(ort_env, (const void*)tmp_onnx_data.data(), tmp_onnx_data.size(), ort_session_opt, &ort_session);
+        if (ort_status)
+        {
+            fprintf(stderr, "ort CreateSession failed %s\n", ort_api->GetErrorMessage(ort_status));
+        }
+
+        OrtRunOptions* ort_run_opt = 0;
+        ort_status = ort_api->CreateRunOptions(&ort_run_opt);
+        if (ort_status)
+        {
+            fprintf(stderr, "ort CreateRunOptions failed %s\n", ort_api->GetErrorMessage(ort_status));
+        }
+
+        OrtAllocator* ort_allocator = 0;
+        ort_status = ort_api->GetAllocatorWithDefaultOptions(&ort_allocator);
+        if (ort_status)
+        {
+            fprintf(stderr, "ort GetAllocatorWithDefaultOptions failed %s\n", ort_api->GetErrorMessage(ort_status));
+        }
+
+        std::vector<const char*> input_names;
+        std::vector<OrtValue*> inputs;
+        for (int i = 0; i < graph->input_size(); i++)
+        {
+            const onnx::ValueInfoProto& value = graph->input(i);
+
+            std::vector<int64_t> shape;
+            const onnx::TensorShapeProto& tsp = value.type().tensor_type().shape();
+            for (int k = 0; k < tsp.dim_size(); k++)
+            {
+                // TODO has_dim_value ?
+                shape.push_back(tsp.dim(k).dim_value());
+            }
+
+            ONNXTensorElementDataType datatype = (ONNXTensorElementDataType)value.type().tensor_type().elem_type();
+
+            OrtValue* ort_val = 0;
+            ort_status = ort_api->CreateTensorAsOrtValue(ort_allocator, (const int64_t*)shape.data(), shape.size(), datatype, &ort_val);
+            if (ort_status)
+            {
+                fprintf(stderr, "ort CreateTensorAsOrtValue failed %s\n", ort_api->GetErrorMessage(ort_status));
+            }
+
+            input_names.push_back(value.name().c_str());
+            inputs.push_back(ort_val);
+        }
+
+        std::vector<const char*> output_names;
+        std::vector<OrtValue*> outputs;
+        for (size_t i = 0; i < intermediates.size(); i++)
+        {
+            output_names.push_back(intermediates[i].c_str());
+            outputs.push_back(0);
+        }
+
+        ort_status = ort_api->Run(ort_session, ort_run_opt,
+                                  input_names.data(), inputs.data(), input_names.size(),
+                                  output_names.data(), output_names.size(), outputs.data());
+        if (ort_status)
+        {
+            fprintf(stderr, "ort Run failed %s\n", ort_api->GetErrorMessage(ort_status));
+        }
+
+        // TODO get output data
+
+        graph->clear_output();
+
+        for (size_t i = 0; i < output_names.size(); i++)
+        {
+            OrtTypeInfo* type_info = 0;
+            ort_status = ort_api->GetTypeInfo(outputs[i], &type_info);
+            if (ort_status)
+            {
+                fprintf(stderr, "ort GetTypeInfo failed %s\n", ort_api->GetErrorMessage(ort_status));
+            }
+
+            ONNXType type = ONNX_TYPE_UNKNOWN;
+            if (type_info)
+            {
+                ort_status = ort_api->GetOnnxTypeFromTypeInfo(type_info, &type);
+                if (ort_status)
+                {
+                    fprintf(stderr, "ort GetOnnxTypeFromTypeInfo failed %s\n", ort_api->GetErrorMessage(ort_status));
+                }
+            }
+
+            if (type == ONNX_TYPE_TENSOR)
+            {
+                OrtTensorTypeAndShapeInfo* info = 0;
+                ort_status = ort_api->GetTensorTypeAndShape(outputs[i], &info);
+                if (ort_status)
+                {
+                    fprintf(stderr, "ort GetTensorTypeAndShape failed %s\n", ort_api->GetErrorMessage(ort_status));
+                }
+
+                ONNXTensorElementDataType datatype = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
+                ort_status = ort_api->GetTensorElementType(info, &datatype);
+                if (ort_status)
+                {
+                    fprintf(stderr, "ort GetTensorElementType failed %s\n", ort_api->GetErrorMessage(ort_status));
+                }
+
+                size_t out_dims = 0;
+                ort_status = ort_api->GetDimensionsCount(info, &out_dims);
+                if (ort_status)
+                {
+                    fprintf(stderr, "ort GetDimensionsCount failed %s\n", ort_api->GetErrorMessage(ort_status));
+                }
+
+                // fprintf(stderr, " out_dims = %lu\n", out_dims);
+
+                std::vector<int64_t> out_shape;
+                out_shape.resize(out_dims);
+                ort_status = ort_api->GetDimensions(info, out_shape.data(), out_dims);
+                if (ort_status)
+                {
+                    fprintf(stderr, "ort GetDimensions failed %s\n",
ort_api->GetErrorMessage(ort_status)); + } + + // fprintf(stderr, "%16s = ", output_names[i]); + // for (size_t j = 0; j < out_dims; j++) + // { + // fprintf(stderr, "%lu ", out_shape[j]); + // } + // fprintf(stderr, "\n"); + + // assign value info + { + onnx::ValueInfoProto* value = 0; + + // maybe output + for (size_t j = 0; j < orig_outputs.size(); j++) + { + if (orig_outputs[j] == output_names[i]) + { + value = graph->add_output(); + value->set_name(output_names[i]); + break; + } + } + if (!value) + { + for (int j = 0; j < graph->value_info_size(); j++) + { + if (graph->mutable_value_info(j)->name() == output_names[i]) + { + value = graph->mutable_value_info(j); + break; + } + } + if (!value) + { + value = graph->add_value_info(); + value->set_name(output_names[i]); + } + } + + // fprintf(stderr, "assign value info %s\n", value->name().c_str()); + + value->mutable_type()->mutable_tensor_type()->set_elem_type((int32_t)datatype); + + onnx::TensorShapeProto* tsp = value->mutable_type()->mutable_tensor_type()->mutable_shape(); + + tsp->clear_dim(); + for (size_t j = 0; j < out_dims; j++) + { + tsp->add_dim()->set_dim_value(out_shape[j]); + } + } + + ort_api->ReleaseTensorTypeAndShapeInfo(info); + } + + if (type_info) + { + ort_api->ReleaseTypeInfo(type_info); + } + } + + for (size_t i = 0; i < input_names.size(); i++) + { + ort_api->ReleaseValue(inputs[i]); + } + + for (size_t i = 0; i < output_names.size(); i++) + { + ort_api->ReleaseValue(outputs[i]); + } + + ort_api->ReleaseRunOptions(ort_run_opt); + ort_api->ReleaseSession(ort_session); + ort_api->ReleaseSessionOptions(ort_session_opt); + ort_api->ReleaseEnv(ort_env); + } +} + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/shape_inference.h b/tools/pnnx/src/pass_onnx/shape_inference.h new file mode 100644 index 00000000000..ea87333451d --- /dev/null +++ b/tools/pnnx/src/pass_onnx/shape_inference.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "onnx.pb.h" + +namespace pnnx { + +namespace onnx2pnnx { + +void shape_inference(onnx::ModelProto& model); + +} // namespace onnx2pnnx + +} // namespace pnnx
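Note on usage (not part of the patch): the new onnx2pnnx helpers are meant to be driven from the ONNX loader. Below is a minimal sketch of how they could fit together; the actual call sites live in load_onnx.cpp, which is only partially shown here, so the driver function name, the include paths, and the pass pipeline in the middle are illustrative assumptions rather than the patch's real code.

// Hypothetical driver, for illustration only.
#include "onnx.pb.h"

#include "pass_onnx/model_stat.h"
#include "pass_onnx/shape_inference.h"

#include <fstream>

static bool convert_onnx(const char* onnxpath)
{
    onnx::ModelProto model;

    std::ifstream ifs(onnxpath, std::ios::binary);
    if (!model.ParseFromIstream(&ifs))
        return false;

    // snapshot op statistics before any rewriting
    pnnx::onnx2pnnx::ModelStat oldstat = pnnx::onnx2pnnx::get_model_stat(model);

    // materialize tensor shapes so later passes can rely on them
    pnnx::onnx2pnnx::shape_inference(model);

    // ... canonicalize / fold_constants / fuse-function passes would run here ...

    // compare against the rewritten graph
    pnnx::onnx2pnnx::ModelStat newstat = pnnx::onnx2pnnx::get_model_stat(model);
    pnnx::onnx2pnnx::print_model_stat(oldstat, newstat);

    return true;
}

get_model_stat() is taken both before and after the rewrite passes so print_model_stat() can render the side-by-side orig/opt table, and shape_inference() has to run early because the nn_* fuse passes read operand shapes (for example, nn.AdaptiveAvgPool2d derives output_size from the inferred output shape).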