[Enhancement] Install Optimizer by setuptools (#690)

* Add fuse select assign pass

* move code to csrc

* add config flag

* Add fuse select assign pass

* Add CSE for ONNX

* remove useless code

* Install optimizer by setuptools

* fix comment
q.yao, 2022-07-25 13:04:27 +08:00, committed by GitHub
commit 5b31d7a60d (parent 36c35b6e88)
7 changed files with 89 additions and 10 deletions


@@ -5,3 +5,6 @@ include mmdeploy/backend/ncnn/*.pyd
 include mmdeploy/lib/*.so
 include mmdeploy/lib/*.dll
 include mmdeploy/lib/*.pyd
+include mmdeploy/backend/torchscript/*.so
+include mmdeploy/backend/torchscript/*.dll
+include mmdeploy/backend/torchscript/*.pyd


@@ -1,4 +1,3 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 add_subdirectory(ops)
-add_subdirectory(optimizer)
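(The optimizer subdirectory is dropped from the CMake build here because its sources are now compiled by setuptools instead; see the get_extensions() hunk in setup.py below.)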


@@ -15,7 +15,7 @@ def model_to_graph__custom_optimizer(ctx, *args, **kwargs):
         assert isinstance(
             custom_passes, Callable
         ), f'Expect a callable onnx_custom_passes, get {type(custom_passes)}.'
-        graph, params_dict, torch_out = custom_passes(graph, params_dict,
+        graph, params_dict, torch_out = custom_passes(ctx, graph, params_dict,
                                                       torch_out)
     return graph, params_dict, torch_out
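With this hunk, any callable supplied as onnx_custom_passes receives the rewriter context as its first positional argument, so a pass can consult export settings on ctx instead of relying on globals. A minimal sketch of a conforming user-defined pass (count_nodes and its body are illustrative, not part of this commit):

    def count_nodes(ctx, graph, params_dict, torch_out):
        # `graph` is the torch._C.Graph produced by _model_to_graph;
        # a real pass would match and rewrite nodes in place here
        num_nodes = sum(1 for _ in graph.nodes())
        print(f'exporting a graph with {num_nodes} nodes')
        return graph, params_dict, torch_out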


@@ -2,7 +2,8 @@
 from mmdeploy.utils import get_root_logger
 
-def optimize_onnx(graph, params_dict, torch_out):
+def optimize_onnx(ctx, graph, params_dict, torch_out):
     """The optimize callback of the onnx model."""
     logger = get_root_logger()
     logger.info('Execute onnx optimize passes.')
     try:


@@ -2,6 +2,12 @@ import os
 
 from setuptools import find_packages, setup
 
+try:
+    from torch.utils.cpp_extension import BuildExtension
+    cmd_class = {'build_ext': BuildExtension}
+except ModuleNotFoundError:
+    cmd_class = {}
+    print('Skip building ext ops due to the absence of torch.')
+
 pwd = os.path.dirname(__file__)
 version_file = 'mmdeploy/version.py'
@@ -96,6 +102,70 @@ def parse_requirements(fname='requirements.txt', with_version=True):
     return packages
 
 
+def get_extensions():
+    extensions = []
+    ext_name = 'mmdeploy.backend.torchscript.ts_optimizer'
+    import glob
+    import platform
+    from torch.utils.cpp_extension import CppExtension
+
+    try:
+        import psutil
+        num_cpu = len(psutil.Process().cpu_affinity())
+        cpu_use = max(4, num_cpu - 1)
+    except (ModuleNotFoundError, AttributeError):
+        cpu_use = 4
+
+    os.environ.setdefault('MAX_JOBS', str(cpu_use))
+    define_macros = []
+
+    # Before PyTorch 1.8.0, when compiling CUDA code, `cxx` is a
+    # required key passed to PyTorch. Even if there is no flag passed
+    # to cxx, users also need to pass an empty list to PyTorch.
+    # Since PyTorch 1.8.0, it has a default value so users do not need
+    # to pass an empty list anymore.
+    # More details at https://github.com/pytorch/pytorch/pull/45956
+    extra_compile_args = {'cxx': []}
+
+    # c++14 is required.
+    # However, on Windows some standard libraries depend on c++17 or
+    # higher, and the compiler will pick an appropriate standard for
+    # those cpp files on its own, so the flag is only added elsewhere.
+    if platform.system() != 'Windows':
+        extra_compile_args['cxx'] = ['-std=c++14']
+
+    include_dirs = []
+    op_files = glob.glob(
+        './csrc/mmdeploy/backend_ops/torchscript/optimizer/*.cpp'
+    ) + glob.glob(
+        './csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/*.cpp'
+    ) + glob.glob(
+        './csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/*.cpp')
+    extension = CppExtension
+
+    if 'nvcc' in extra_compile_args and platform.system() != 'Windows':
+        extra_compile_args['nvcc'] += ['-std=c++14']
+
+    ext_ops = extension(
+        name=ext_name,
+        sources=op_files,
+        include_dirs=include_dirs,
+        define_macros=define_macros,
+        extra_compile_args=extra_compile_args)
+    extensions.append(ext_ops)
+    return extensions
+
+
 if __name__ == '__main__':
     setup(
         name='mmdeploy',
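MAX_JOBS, seeded in the hunk above, is the environment variable torch.utils.cpp_extension consults to cap parallel compile jobs. A standalone sketch of the same fallback logic, assuming only that psutil may be missing or that the platform lacks cpu_affinity (e.g. macOS); default_max_jobs is an illustrative name:

    import os

    def default_max_jobs(floor=4):
        # mirror the hunk above: prefer the process CPU affinity over
        # the raw core count, and fall back to `floor` jobs when psutil
        # is absent or cpu_affinity is unsupported
        try:
            import psutil
            num_cpu = len(psutil.Process().cpu_affinity())
            return max(floor, num_cpu - 1)
        except (ImportError, AttributeError):
            return floor

    os.environ.setdefault('MAX_JOBS', str(default_max_jobs()))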
@@ -128,6 +198,6 @@ if __name__ == '__main__':
             'build': parse_requirements('requirements/build.txt'),
             'optional': parse_requirements('requirements/optional.txt'),
         },
-        ext_modules=[],
-        cmdclass={},
+        ext_modules=get_extensions(),
+        cmdclass=cmd_class,
         zip_safe=False)
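With ext_modules and cmdclass wired up, a source build such as `pip install -e .` or `python setup.py build_ext --inplace` should compile the optimizer sources into a ts_optimizer extension inside mmdeploy/backend/torchscript/, which is what the new MANIFEST.in globs package. A quick sanity check, assuming torch was present at build time so the extension was actually built:

    # illustrative check that the setuptools-built extension loads
    from mmdeploy.backend.torchscript import ts_optimizer
    print(ts_optimizer.__file__)  # e.g. .../backend/torchscript/ts_optimizer.*.so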


@@ -30,7 +30,7 @@ def test_merge_shape_concate():
     except ImportError:
         pytest.skip('pass not found.')
 
-    def _optimize_onnx(graph, params_dict, torch_out):
+    def _optimize_onnx(ctx, graph, params_dict, torch_out):
         opt_pass(graph)
         return graph, params_dict, torch_out
@@ -82,7 +82,7 @@ def test_peephole():
     except ImportError:
         pytest.skip('pass not found.')
 
-    def _optimize_onnx(graph, params_dict, torch_out):
+    def _optimize_onnx(ctx, graph, params_dict, torch_out):
         opt_pass(graph)
         return graph, params_dict, torch_out
@@ -148,7 +148,7 @@ def test_flatten_cls_head():
     except ImportError:
         pytest.skip('pass not found.')
 
-    def _optimize_onnx(graph, params_dict, torch_out):
+    def _optimize_onnx(ctx, graph, params_dict, torch_out):
         opt_pass(graph)
         return graph, params_dict, torch_out
@@ -199,7 +199,7 @@ def test_fuse_select_assign():
     except ImportError:
         pytest.skip('pass not found.')
 
-    def _optimize_onnx(graph, params_dict, torch_out):
+    def _optimize_onnx(ctx, graph, params_dict, torch_out):
         opt_pass(graph, params_dict)
         return graph, params_dict, torch_out
@@ -247,7 +247,7 @@ def test_common_subgraph_elimination():
     except ImportError:
         pytest.skip('pass not found.')
 
-    def _optimize_onnx(graph, params_dict, torch_out):
+    def _optimize_onnx(ctx, graph, params_dict, torch_out):
         opt_pass(graph, params_dict)
         return graph, params_dict, torch_out
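Every test above defines the same shim, now taking ctx first, and applies a single native pass. A hedged sketch of how such a shim is typically plugged into an export, assuming mmdeploy's RewriterContext forwards the onnx_custom_passes keyword to the rewritten _model_to_graph (as the optimizer.py hunk above implies); model, inputs and opt_pass stand in for objects each test constructs:

    import torch
    from mmdeploy.core import RewriterContext

    def _optimize_onnx(ctx, graph, params_dict, torch_out):
        opt_pass(graph)  # one of the native passes loaded via ts_optimizer
        return graph, params_dict, torch_out

    with RewriterContext({}, onnx_custom_passes=_optimize_onnx):
        torch.onnx.export(model, inputs, 'tmp.onnx', opset_version=11)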


@@ -133,6 +133,12 @@ def clear_mmdeploy(mmdeploy_dir: str):
     for ncnn_ext_path in ncnn_ext_paths:
         os.remove(ncnn_ext_path)
 
+    # remove ts_optimizer
+    ts_optimizer_paths = glob(
+        osp.join(mmdeploy_dir, 'mmdeploy/backend/torchscript/ts_optimizer.*'))
+    for ts_optimizer_path in ts_optimizer_paths:
+        os.remove(ts_optimizer_path)
+
 
 def build_mmdeploy(cfg, mmdeploy_dir, dist_dir=None):
     cmake_flags = cfg.get('cmake_flags', [])