diff --git a/MANIFEST.in b/MANIFEST.in index 7c85a3240b..f3427de2fe 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -5,3 +5,6 @@ include mmdeploy/backend/ncnn/*.pyd include mmdeploy/lib/*.so include mmdeploy/lib/*.dll include mmdeploy/lib/*.pyd +include mmdeploy/backend/torchscript/*.so +include mmdeploy/backend/torchscript/*.dll +include mmdeploy/backend/torchscript/*.pyd diff --git a/csrc/mmdeploy/backend_ops/torchscript/CMakeLists.txt b/csrc/mmdeploy/backend_ops/torchscript/CMakeLists.txt index 8d862b9411..4b080f621a 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/CMakeLists.txt +++ b/csrc/mmdeploy/backend_ops/torchscript/CMakeLists.txt @@ -1,4 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. add_subdirectory(ops) -add_subdirectory(optimizer) diff --git a/mmdeploy/apis/onnx/optimizer.py b/mmdeploy/apis/onnx/optimizer.py index 40af5a888f..b9d2ead0c0 100644 --- a/mmdeploy/apis/onnx/optimizer.py +++ b/mmdeploy/apis/onnx/optimizer.py @@ -15,7 +15,7 @@ def model_to_graph__custom_optimizer(ctx, *args, **kwargs): assert isinstance( custom_passes, Callable ), f'Expect a callable onnx_custom_passes, get {type(custom_passes)}.' 
- graph, params_dict, torch_out = custom_passes(graph, params_dict, + graph, params_dict, torch_out = custom_passes(ctx, graph, params_dict, torch_out) return graph, params_dict, torch_out diff --git a/mmdeploy/apis/onnx/passes/optimize_onnx.py b/mmdeploy/apis/onnx/passes/optimize_onnx.py index 8b12c6bf92..19e14bc292 100644 --- a/mmdeploy/apis/onnx/passes/optimize_onnx.py +++ b/mmdeploy/apis/onnx/passes/optimize_onnx.py @@ -2,7 +2,8 @@ from mmdeploy.utils import get_root_logger -def optimize_onnx(graph, params_dict, torch_out): +def optimize_onnx(ctx, graph, params_dict, torch_out): + """The optimize callback of the onnx model.""" logger = get_root_logger() logger.info('Execute onnx optimize passes.') try: diff --git a/setup.py b/setup.py index 86e5cdf022..634423a5ae 100644 --- a/setup.py +++ b/setup.py @@ -2,6 +2,12 @@ from setuptools import find_packages, setup +try: + from torch.utils.cpp_extension import BuildExtension + cmd_class = {'build_ext': BuildExtension} +except ModuleNotFoundError: + cmd_class = {} + print('Skip building ext ops due to the absence of torch.') pwd = os.path.dirname(__file__) version_file = 'mmdeploy/version.py' @@ -96,6 +102,70 @@ def gen_packages_items(): return packages +def get_extensions(): + extensions = [] + ext_name = 'mmdeploy.backend.torchscript.ts_optimizer' + import glob + import platform + + from torch.utils.cpp_extension import CppExtension + + try: + import psutil + num_cpu = len(psutil.Process().cpu_affinity()) + cpu_use = max(4, num_cpu - 1) + except (ModuleNotFoundError, AttributeError): + cpu_use = 4 + + os.environ.setdefault('MAX_JOBS', str(cpu_use)) + define_macros = [] + + # Before PyTorch1.8.0, when compiling CUDA code, `cxx` is a + # required key passed to PyTorch. Even if there is no flag passed + # to cxx, users also need to pass an empty list to PyTorch. + # Since PyTorch1.8.0, it has a default value so users do not need + # to pass an empty list anymore. 
+ # More details at https://github.com/pytorch/pytorch/pull/45956 + extra_compile_args = {'cxx': []} + + # c++14 is required. + # However, in the windows environment, some standard libraries + # will depend on c++17 or higher. In fact, for the windows + # environment, the compiler will choose the appropriate compiler + # to compile those cpp files, so there is no need to add the + # argument + if platform.system() != 'Windows': + extra_compile_args['cxx'] = ['-std=c++14'] + + include_dirs = [] + + op_files = glob.glob( + './csrc/mmdeploy/backend_ops/torchscript/optimizer/*.cpp' + ) + glob.glob( + './csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/*.cpp' + ) + glob.glob( + './csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/*.cpp') + extension = CppExtension + + # c++14 is required. + # However, in the windows environment, some standard libraries + # will depend on c++17 or higher. In fact, for the windows + # environment, the compiler will choose the appropriate compiler + # to compile those cpp files, so there is no need to add the + # argument + if 'nvcc' in extra_compile_args and platform.system() != 'Windows': + extra_compile_args['nvcc'] += ['-std=c++14'] + + ext_ops = extension( + name=ext_name, + sources=op_files, + include_dirs=include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args) + extensions.append(ext_ops) + return extensions + + if __name__ == '__main__': setup( name='mmdeploy', @@ -128,6 +198,6 @@ def gen_packages_items(): 'build': parse_requirements('requirements/build.txt'), 'optional': parse_requirements('requirements/optional.txt'), }, - ext_modules=[], - cmdclass={}, + ext_modules=get_extensions(), + cmdclass=cmd_class, zip_safe=False) diff --git a/tests/test_apis/test_onnx_passes.py b/tests/test_apis/test_onnx_passes.py index 420ea2572f..a2d77b4463 100644 --- a/tests/test_apis/test_onnx_passes.py +++ b/tests/test_apis/test_onnx_passes.py @@ -30,7 +30,7 @@ def test_merge_shape_concate(): except 
ImportError: pytest.skip('pass not found.') - def _optimize_onnx(graph, params_dict, torch_out): + def _optimize_onnx(ctx, graph, params_dict, torch_out): opt_pass(graph) return graph, params_dict, torch_out @@ -82,7 +82,7 @@ def test_peephole(): except ImportError: pytest.skip('pass not found.') - def _optimize_onnx(graph, params_dict, torch_out): + def _optimize_onnx(ctx, graph, params_dict, torch_out): opt_pass(graph) return graph, params_dict, torch_out @@ -148,7 +148,7 @@ def test_flatten_cls_head(): except ImportError: pytest.skip('pass not found.') - def _optimize_onnx(graph, params_dict, torch_out): + def _optimize_onnx(ctx, graph, params_dict, torch_out): opt_pass(graph) return graph, params_dict, torch_out @@ -199,7 +199,7 @@ def test_fuse_select_assign(): except ImportError: pytest.skip('pass not found.') - def _optimize_onnx(graph, params_dict, torch_out): + def _optimize_onnx(ctx, graph, params_dict, torch_out): opt_pass(graph, params_dict) return graph, params_dict, torch_out @@ -247,7 +247,7 @@ def test_common_subgraph_elimination(): except ImportError: pytest.skip('pass not found.') - def _optimize_onnx(graph, params_dict, torch_out): + def _optimize_onnx(ctx, graph, params_dict, torch_out): opt_pass(graph, params_dict) return graph, params_dict, torch_out diff --git a/tools/package_tools/mmdeploy_builder.py b/tools/package_tools/mmdeploy_builder.py index b84a0daf42..bdf70ed2e3 100644 --- a/tools/package_tools/mmdeploy_builder.py +++ b/tools/package_tools/mmdeploy_builder.py @@ -133,6 +133,12 @@ def _remove_in_mmdeploy(path): for ncnn_ext_path in ncnn_ext_paths: os.remove(ncnn_ext_path) + # remove ts_optimizer + ts_optimizer_paths = glob( + osp.join(mmdeploy_dir, 'mmdeploy/backend/torchscript/ts_optimizer.*')) + for ts_optimizer_path in ts_optimizer_paths: + os.remove(ts_optimizer_path) + def build_mmdeploy(cfg, mmdeploy_dir, dist_dir=None): cmake_flags = cfg.get('cmake_flags', [])