Skip to content

Commit

Permalink
[Enhancement] Install Optimizer by setuptools (#690)
Browse files Browse the repository at this point in the history
* Add fuse select assign pass

* move code to csrc

* add config flag

* Add fuse select assign pass

* Add CSE for ONNX

* remove useless code

* Install optimizer by setuptools

* fix comment
  • Loading branch information
q.yao authored Jul 25, 2022
1 parent 36c35b6 commit 5b31d7a
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 10 deletions.
3 changes: 3 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ include mmdeploy/backend/ncnn/*.pyd
include mmdeploy/lib/*.so
include mmdeploy/lib/*.dll
include mmdeploy/lib/*.pyd
include mmdeploy/backend/torchscript/*.so
include mmdeploy/backend/torchscript/*.dll
include mmdeploy/backend/torchscript/*.pyd
1 change: 0 additions & 1 deletion csrc/mmdeploy/backend_ops/torchscript/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.

add_subdirectory(ops)
add_subdirectory(optimizer)
2 changes: 1 addition & 1 deletion mmdeploy/apis/onnx/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def model_to_graph__custom_optimizer(ctx, *args, **kwargs):
assert isinstance(
custom_passes, Callable
), f'Expect a callable onnx_custom_passes, get {type(custom_passes)}.'
graph, params_dict, torch_out = custom_passes(graph, params_dict,
graph, params_dict, torch_out = custom_passes(ctx, graph, params_dict,
torch_out)

return graph, params_dict, torch_out
Expand Down
3 changes: 2 additions & 1 deletion mmdeploy/apis/onnx/passes/optimize_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from mmdeploy.utils import get_root_logger


def optimize_onnx(graph, params_dict, torch_out):
def optimize_onnx(ctx, graph, params_dict, torch_out):
"""The optimize callback of the onnx model."""
logger = get_root_logger()
logger.info('Execute onnx optimize passes.')
try:
Expand Down
74 changes: 72 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

from setuptools import find_packages, setup

# Use torch's BuildExtension to drive the C++ extension build when torch is
# importable; otherwise fall back to a plain setuptools build (no ext ops).
cmd_class = {}
try:
    from torch.utils.cpp_extension import BuildExtension
except ModuleNotFoundError:
    print('Skip building ext ops due to the absence of torch.')
else:
    cmd_class = {'build_ext': BuildExtension}
pwd = os.path.dirname(__file__)
version_file = 'mmdeploy/version.py'

Expand Down Expand Up @@ -96,6 +102,70 @@ def gen_packages_items():
return packages


def get_extensions():
    """Build the list of C++ extension modules for ``setup()``.

    Returns:
        list: A single ``CppExtension`` that compiles the torchscript
        optimizer sources into ``mmdeploy.backend.torchscript.ts_optimizer``.
    """
    extensions = []
    ext_name = 'mmdeploy.backend.torchscript.ts_optimizer'
    # Imported lazily so that setup.py can be loaded (e.g. for metadata)
    # on machines without torch installed.
    import glob
    import platform

    from torch.utils.cpp_extension import CppExtension

    # Cap parallel compile jobs at the CPUs this process may use (at least
    # 4), unless the user already set MAX_JOBS themselves.
    try:
        import psutil
        num_cpu = len(psutil.Process().cpu_affinity())
        cpu_use = max(4, num_cpu - 1)
    except (ModuleNotFoundError, AttributeError):
        # psutil missing, or cpu_affinity unsupported on this platform.
        cpu_use = 4

    os.environ.setdefault('MAX_JOBS', str(cpu_use))
    define_macros = []

    # Before PyTorch1.8.0, when compiling CUDA code, `cxx` is a
    # required key passed to PyTorch. Even if there is no flag passed
    # to cxx, users also need to pass an empty list to PyTorch.
    # Since PyTorch1.8.0, it has a default value so users do not need
    # to pass an empty list anymore.
    # More details at https://github.com/pytorch/pytorch/pull/45956
    extra_compile_args = {'cxx': []}

    # c++14 is required.
    # However, in the windows environment, some standard libraries
    # will depend on c++17 or higher. In fact, for the windows
    # environment, the compiler will choose the appropriate compiler
    # to compile those cpp files, so there is no need to add the
    # argument
    if platform.system() != 'Windows':
        extra_compile_args['cxx'] = ['-std=c++14']

    include_dirs = []

    # All optimizer sources: top level, IR helpers, and the ONNX passes.
    op_root = './csrc/mmdeploy/backend_ops/torchscript/optimizer'
    op_files = (
        glob.glob(op_root + '/*.cpp') + glob.glob(op_root + '/ir/*.cpp') +
        glob.glob(op_root + '/passes/onnx/*.cpp'))

    # NOTE: the original code carried a dead `if 'nvcc' in
    # extra_compile_args` branch (the dict only ever has a 'cxx' key);
    # it has been removed. This is a CPU-only CppExtension — no nvcc.
    ext_ops = CppExtension(
        name=ext_name,
        sources=op_files,
        include_dirs=include_dirs,
        define_macros=define_macros,
        extra_compile_args=extra_compile_args)
    extensions.append(ext_ops)
    return extensions


if __name__ == '__main__':
setup(
name='mmdeploy',
Expand Down Expand Up @@ -128,6 +198,6 @@ def gen_packages_items():
'build': parse_requirements('requirements/build.txt'),
'optional': parse_requirements('requirements/optional.txt'),
},
ext_modules=[],
cmdclass={},
ext_modules=get_extensions(),
cmdclass=cmd_class,
zip_safe=False)
10 changes: 5 additions & 5 deletions tests/test_apis/test_onnx_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_merge_shape_concate():
except ImportError:
pytest.skip('pass not found.')

def _optimize_onnx(graph, params_dict, torch_out):
def _optimize_onnx(ctx, graph, params_dict, torch_out):
opt_pass(graph)
return graph, params_dict, torch_out

Expand Down Expand Up @@ -82,7 +82,7 @@ def test_peephole():
except ImportError:
pytest.skip('pass not found.')

def _optimize_onnx(graph, params_dict, torch_out):
def _optimize_onnx(ctx, graph, params_dict, torch_out):
opt_pass(graph)
return graph, params_dict, torch_out

Expand Down Expand Up @@ -148,7 +148,7 @@ def test_flatten_cls_head():
except ImportError:
pytest.skip('pass not found.')

def _optimize_onnx(graph, params_dict, torch_out):
def _optimize_onnx(ctx, graph, params_dict, torch_out):
opt_pass(graph)
return graph, params_dict, torch_out

Expand Down Expand Up @@ -199,7 +199,7 @@ def test_fuse_select_assign():
except ImportError:
pytest.skip('pass not found.')

def _optimize_onnx(graph, params_dict, torch_out):
def _optimize_onnx(ctx, graph, params_dict, torch_out):
opt_pass(graph, params_dict)
return graph, params_dict, torch_out

Expand Down Expand Up @@ -247,7 +247,7 @@ def test_common_subgraph_elimination():
except ImportError:
pytest.skip('pass not found.')

def _optimize_onnx(graph, params_dict, torch_out):
def _optimize_onnx(ctx, graph, params_dict, torch_out):
opt_pass(graph, params_dict)
return graph, params_dict, torch_out

Expand Down
6 changes: 6 additions & 0 deletions tools/package_tools/mmdeploy_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ def _remove_in_mmdeploy(path):
for ncnn_ext_path in ncnn_ext_paths:
os.remove(ncnn_ext_path)

# remove ts_optimizer
ts_optimizer_paths = glob(
osp.join(mmdeploy_dir, 'mmdeploy/backend/torchscript/ts_optimizer.*'))
for ts_optimizer_path in ts_optimizer_paths:
os.remove(ts_optimizer_path)


def build_mmdeploy(cfg, mmdeploy_dir, dist_dir=None):
cmake_flags = cfg.get('cmake_flags', [])
Expand Down

0 comments on commit 5b31d7a

Please sign in to comment.