Skip to content

Commit

Permalink
pnnx load dynamo onnx (#5363)
Browse files Browse the repository at this point in the history
* split load torchscript sources from pnnx, drop cxxabi hack, link torchscript with whole-archive

* check compiler cxx11 abi

* eliminate some onnx noop
  • Loading branch information
nihui authored Mar 19, 2024
1 parent 25c4278 commit a55fe1c
Show file tree
Hide file tree
Showing 57 changed files with 5,552 additions and 13 deletions.
2 changes: 2 additions & 0 deletions .ci/pnnx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ jobs:
- name: test
run: |
export PYTHONUSERBASE=${{ci.workspace}}/torch-${{matrix.torch-version}}
export LD_LIBRARY_PATH=${{ci.workspace}}/torchvision-${{matrix.torchvision-version}}-install/lib
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
export MKL_ENABLE_INSTRUCTIONS=SSE4_2
Expand All @@ -131,6 +132,7 @@ jobs:
- name: python-pnnx
run: |
export PYTHONUSERBASE=${{ci.workspace}}/torch-${{matrix.torch-version}}
export LD_LIBRARY_PATH=${{ci.workspace}}/torchvision-${{matrix.torchvision-version}}-install/lib
export PNNX_WHEEL_WITHOUT_BUILD=ON
cd tools/pnnx/python
cp ../build/src/pnnx pnnx/
Expand Down
15 changes: 6 additions & 9 deletions tools/pnnx/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,17 @@ include(PNNXPyTorch)
set(CMAKE_CXX_STANDARD 14)

# set(CMAKE_BUILD_TYPE debug)
#set(CMAKE_BUILD_TYPE relwithdebinfo)
# set(CMAKE_BUILD_TYPE relwithdebinfo)
# set(CMAKE_BUILD_TYPE release)

option(PNNX_COVERAGE "build for coverage" OFF)

#set(Torch_INSTALL_DIR "/home/nihui/.local/lib/python3.9/site-packages/torch" CACHE STRING "")
#set(Torch_INSTALL_DIR "/home/nihui/osd/pnnx/pytorch-v1.10.0/build/install" CACHE STRING "")
# set(Torch_INSTALL_DIR "/home/nihui/osd/pnnx/install" CACHE STRING "")
set(TorchVision_INSTALL_DIR "/home/nihui/osd/vision/build/install" CACHE STRING "")
# set(TorchVision_INSTALL_DIR "/home/nihui/osd/pnnx/install" CACHE STRING "")

#set(Torch_DIR "${Torch_INSTALL_DIR}/share/cmake/Torch")
set(TorchVision_DIR "${TorchVision_INSTALL_DIR}/share/cmake/TorchVision")

# test if libtorch and protobuf has the same cxxabi version

find_package(Python3 COMPONENTS Interpreter Development)

PNNXProbeForPyTorchInstall()
Expand All @@ -67,11 +63,11 @@ find_library(TORCHVISION_LIBRARY torchvision PATHS "${TorchVision_INSTALL_DIR}/l
if(TORCHVISION_LIBRARY)
message(STATUS "Found TorchVision: ${TORCHVISION_LIBRARY}")
if(APPLE)
list(APPEND TORCHVISION_LIBRARY "-Wl,-force_load,${TORCHVISION_LIBRARY}")
set(TORCHVISION_LIBRARY "-Wl,-force_load,${TORCHVISION_LIBRARY}")
elseif(MSVC)
list(APPEND TORCHVISION_LIBRARY "-WHOLEARCHIVE:${TORCHVISION_LIBRARY}")
set(TORCHVISION_LIBRARY "-WHOLEARCHIVE:${TORCHVISION_LIBRARY}")
else()
list(APPEND TORCHVISION_LIBRARY "-Wl,--whole-archive ${TORCHVISION_LIBRARY} -Wl,--no-whole-archive")
set(TORCHVISION_LIBRARY "-Wl,--whole-archive ${TORCHVISION_LIBRARY} -Wl,--no-whole-archive")
endif()
set(TorchVision_FOUND TRUE)
message(STATUS "Building with TorchVision")
Expand All @@ -84,6 +80,7 @@ endif()
include_directories(${TORCH_INCLUDE_DIRS})

if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
# test if libtorch and protobuf has the same cxxabi version
include(CheckCXXSourceCompiles)
set(CMAKE_REQUIRED_FLAGS "${TORCH_CXX_FLAGS}")
check_cxx_source_compiles("#include <cxxabi.h>\n#if _GLIBCXX_USE_CXX11_ABI\nint main() { return 0; }\n#endif" PNNX_TORCH_USE_CXX11_ABI)
Expand Down
20 changes: 18 additions & 2 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -621,20 +621,36 @@ endif()
# set(pnnx_pass_onnx_SRCS
# pass_onnx/canonicalize.cpp
# pass_onnx/dead_code_elimination.cpp
# pass_onnx/eliminate_noop.cpp
# pass_onnx/fold_constants.cpp
# pass_onnx/inline_containers.cpp
# pass_onnx/model_stat.cpp
# pass_onnx/shape_inference.cpp
#
# pass_onnx/nn_AdaptiveAvgPool2d.cpp
# pass_onnx/nn_AdaptiveAvgPool3d.cpp
# pass_onnx/nn_AvgPool2d.cpp
# pass_onnx/nn_AvgPool3d.cpp
# pass_onnx/nn_BatchNorm2d.cpp
# pass_onnx/nn_BatchNorm3d.cpp
# pass_onnx/nn_Conv2d.cpp
# pass_onnx/nn_Conv3d.cpp
# pass_onnx/nn_GELU.cpp
# pass_onnx/nn_LayerNorm.cpp
# pass_onnx/nn_Linear.cpp
# pass_onnx/nn_MaxPool2d.cpp
# pass_onnx/nn_MaxPool3d.cpp
# pass_onnx/nn_MultiheadAttention.cpp
# )
#
# set(onnx2pnnx_SRCS
# pass_onnx.cpp
# ${pnnx_pass_onnx_SRCS}
# load_onnx.cpp
# )
#
# add_library(onnx2pnnx STATIC ${onnx2pnnx_SRCS})
# add_library(onnx2pnnx OBJECT ${onnx2pnnx_SRCS})
# target_link_libraries(onnx2pnnx PRIVATE onnxproto onnxruntime::onnxruntime)
#
# target_compile_definitions(onnx2pnnx PRIVATE BUILD_ONNX2PNNX)
#
# message(STATUS "Building with dynamo-onnx")
Expand Down
5 changes: 3 additions & 2 deletions tools/pnnx/src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -943,7 +943,7 @@ static std::string sanitize_identifier(const std::string& s)
std::string ss = s;
for (size_t i = 0; i < ss.size(); i++)
{
if (ss[i] == '.' || ss[i] == ':')
if (ss[i] == '.' || ss[i] == ':' || ss[i] == '/')
ss[i] = '_';
}

Expand Down Expand Up @@ -2771,7 +2771,8 @@ int Graph::parse(const std::string& param)
void Operand::remove_consumer(const Operator* c)
{
auto it = std::find(consumers.begin(), consumers.end(), c);
consumers.erase(it);
if (it != consumers.end())
consumers.erase(it);
}

Operator* Graph::new_operator(const std::string& type, const std::string& name)
Expand Down
11 changes: 11 additions & 0 deletions tools/pnnx/src/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,15 @@ class Tensor;

#if BUILD_ONNX2PNNX
namespace onnx {
class AttributeProto;
class TensorProto;
class ValueInfoProto;
} // namespace onnx
namespace pnnx {
namespace onnx2pnnx {
class OnnxAttributeProxy;
} // namespace onnx2pnnx
} // namespace pnnx
#endif // BUILD_ONNX2PNNX

namespace pnnx {
Expand Down Expand Up @@ -187,6 +193,10 @@ class Parameter
Parameter(const torch::jit::Node* value_node);
Parameter(const torch::jit::Value* value);
#endif // BUILD_TORCH2PNNX
#if BUILD_ONNX2PNNX
Parameter(const onnx::AttributeProto& attr);
Parameter(const onnx2pnnx::OnnxAttributeProxy& attr);
#endif // BUILD_ONNX2PNNX

static Parameter parse_from_string(const std::string& value);
static std::string encode_to_string(const Parameter& param);
Expand Down Expand Up @@ -325,6 +335,7 @@ class Graph
#endif
#if BUILD_ONNX2PNNX
Operand* new_operand(const onnx::ValueInfoProto& value);
Operand* new_operand(const onnx::TensorProto& t);
#endif

Operand* new_operand(const std::string& name);
Expand Down
Loading

0 comments on commit a55fe1c

Please sign in to comment.