[Enhancement] Install Optimizer by setuptools (#690)
* Add fuse select assign pass * move code to csrc * add config flag * Add fuse select assign pass * Add CSE for ONNX * remove useless code * Install optimizer by setup tools * fix commentpull/817/head
parent
36c35b6e88
commit
5b31d7a60d
|
@ -5,3 +5,6 @@ include mmdeploy/backend/ncnn/*.pyd
|
|||
include mmdeploy/lib/*.so
|
||||
include mmdeploy/lib/*.dll
|
||||
include mmdeploy/lib/*.pyd
|
||||
include mmdeploy/backend/torchscript/*.so
|
||||
include mmdeploy/backend/torchscript/*.dll
|
||||
include mmdeploy/backend/torchscript/*.pyd
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
add_subdirectory(ops)
|
||||
add_subdirectory(optimizer)
|
||||
|
|
|
@ -15,7 +15,7 @@ def model_to_graph__custom_optimizer(ctx, *args, **kwargs):
|
|||
assert isinstance(
|
||||
custom_passes, Callable
|
||||
), f'Expect a callable onnx_custom_passes, get {type(custom_passes)}.'
|
||||
graph, params_dict, torch_out = custom_passes(graph, params_dict,
|
||||
graph, params_dict, torch_out = custom_passes(ctx, graph, params_dict,
|
||||
torch_out)
|
||||
|
||||
return graph, params_dict, torch_out
|
||||
|
|
|
@ -2,7 +2,8 @@
|
|||
from mmdeploy.utils import get_root_logger
|
||||
|
||||
|
||||
def optimize_onnx(graph, params_dict, torch_out):
|
||||
def optimize_onnx(ctx, graph, params_dict, torch_out):
|
||||
"""The optimize callback of the onnx model."""
|
||||
logger = get_root_logger()
|
||||
logger.info('Execute onnx optimize passes.')
|
||||
try:
|
||||
|
|
74
setup.py
74
setup.py
|
@ -2,6 +2,12 @@ import os
|
|||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
try:
|
||||
from torch.utils.cpp_extension import BuildExtension
|
||||
cmd_class = {'build_ext': BuildExtension}
|
||||
except ModuleNotFoundError:
|
||||
cmd_class = {}
|
||||
print('Skip building ext ops due to the absence of torch.')
|
||||
pwd = os.path.dirname(__file__)
|
||||
version_file = 'mmdeploy/version.py'
|
||||
|
||||
|
@ -96,6 +102,70 @@ def parse_requirements(fname='requirements.txt', with_version=True):
|
|||
return packages
|
||||
|
||||
|
||||
def get_extensions():
|
||||
extensions = []
|
||||
ext_name = 'mmdeploy.backend.torchscript.ts_optimizer'
|
||||
import glob
|
||||
import platform
|
||||
|
||||
from torch.utils.cpp_extension import CppExtension
|
||||
|
||||
try:
|
||||
import psutil
|
||||
num_cpu = len(psutil.Process().cpu_affinity())
|
||||
cpu_use = max(4, num_cpu - 1)
|
||||
except (ModuleNotFoundError, AttributeError):
|
||||
cpu_use = 4
|
||||
|
||||
os.environ.setdefault('MAX_JOBS', str(cpu_use))
|
||||
define_macros = []
|
||||
|
||||
# Before PyTorch1.8.0, when compiling CUDA code, `cxx` is a
|
||||
# required key passed to PyTorch. Even if there is no flag passed
|
||||
# to cxx, users also need to pass an empty list to PyTorch.
|
||||
# Since PyTorch1.8.0, it has a default value so users do not need
|
||||
# to pass an empty list anymore.
|
||||
# More details at https://github.com/pytorch/pytorch/pull/45956
|
||||
extra_compile_args = {'cxx': []}
|
||||
|
||||
# c++14 is required.
|
||||
# However, in the windows environment, some standard libraries
|
||||
# will depend on c++17 or higher. In fact, for the windows
|
||||
# environment, the compiler will choose the appropriate compiler
|
||||
# to compile those cpp files, so there is no need to add the
|
||||
# argument
|
||||
if platform.system() != 'Windows':
|
||||
extra_compile_args['cxx'] = ['-std=c++14']
|
||||
|
||||
include_dirs = []
|
||||
|
||||
op_files = glob.glob(
|
||||
'./csrc/mmdeploy/backend_ops/torchscript/optimizer/*.cpp'
|
||||
) + glob.glob(
|
||||
'./csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/*.cpp'
|
||||
) + glob.glob(
|
||||
'./csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/*.cpp')
|
||||
extension = CppExtension
|
||||
|
||||
# c++14 is required.
|
||||
# However, in the windows environment, some standard libraries
|
||||
# will depend on c++17 or higher. In fact, for the windows
|
||||
# environment, the compiler will choose the appropriate compiler
|
||||
# to compile those cpp files, so there is no need to add the
|
||||
# argument
|
||||
if 'nvcc' in extra_compile_args and platform.system() != 'Windows':
|
||||
extra_compile_args['nvcc'] += ['-std=c++14']
|
||||
|
||||
ext_ops = extension(
|
||||
name=ext_name,
|
||||
sources=op_files,
|
||||
include_dirs=include_dirs,
|
||||
define_macros=define_macros,
|
||||
extra_compile_args=extra_compile_args)
|
||||
extensions.append(ext_ops)
|
||||
return extensions
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
setup(
|
||||
name='mmdeploy',
|
||||
|
@ -128,6 +198,6 @@ if __name__ == '__main__':
|
|||
'build': parse_requirements('requirements/build.txt'),
|
||||
'optional': parse_requirements('requirements/optional.txt'),
|
||||
},
|
||||
ext_modules=[],
|
||||
cmdclass={},
|
||||
ext_modules=get_extensions(),
|
||||
cmdclass=cmd_class,
|
||||
zip_safe=False)
|
||||
|
|
|
@ -30,7 +30,7 @@ def test_merge_shape_concate():
|
|||
except ImportError:
|
||||
pytest.skip('pass not found.')
|
||||
|
||||
def _optimize_onnx(graph, params_dict, torch_out):
|
||||
def _optimize_onnx(ctx, graph, params_dict, torch_out):
|
||||
opt_pass(graph)
|
||||
return graph, params_dict, torch_out
|
||||
|
||||
|
@ -82,7 +82,7 @@ def test_peephole():
|
|||
except ImportError:
|
||||
pytest.skip('pass not found.')
|
||||
|
||||
def _optimize_onnx(graph, params_dict, torch_out):
|
||||
def _optimize_onnx(ctx, graph, params_dict, torch_out):
|
||||
opt_pass(graph)
|
||||
return graph, params_dict, torch_out
|
||||
|
||||
|
@ -148,7 +148,7 @@ def test_flatten_cls_head():
|
|||
except ImportError:
|
||||
pytest.skip('pass not found.')
|
||||
|
||||
def _optimize_onnx(graph, params_dict, torch_out):
|
||||
def _optimize_onnx(ctx, graph, params_dict, torch_out):
|
||||
opt_pass(graph)
|
||||
return graph, params_dict, torch_out
|
||||
|
||||
|
@ -199,7 +199,7 @@ def test_fuse_select_assign():
|
|||
except ImportError:
|
||||
pytest.skip('pass not found.')
|
||||
|
||||
def _optimize_onnx(graph, params_dict, torch_out):
|
||||
def _optimize_onnx(ctx, graph, params_dict, torch_out):
|
||||
opt_pass(graph, params_dict)
|
||||
return graph, params_dict, torch_out
|
||||
|
||||
|
@ -247,7 +247,7 @@ def test_common_subgraph_elimination():
|
|||
except ImportError:
|
||||
pytest.skip('pass not found.')
|
||||
|
||||
def _optimize_onnx(graph, params_dict, torch_out):
|
||||
def _optimize_onnx(ctx, graph, params_dict, torch_out):
|
||||
opt_pass(graph, params_dict)
|
||||
return graph, params_dict, torch_out
|
||||
|
||||
|
|
|
@ -133,6 +133,12 @@ def clear_mmdeploy(mmdeploy_dir: str):
|
|||
for ncnn_ext_path in ncnn_ext_paths:
|
||||
os.remove(ncnn_ext_path)
|
||||
|
||||
# remove ts_optmizer
|
||||
ts_optimizer_paths = glob(
|
||||
osp.join(mmdeploy_dir, 'mmdeploy/backend/torchscript/ts_optimizer.*'))
|
||||
for ts_optimizer_path in ts_optimizer_paths:
|
||||
os.remove(ts_optimizer_path)
|
||||
|
||||
|
||||
def build_mmdeploy(cfg, mmdeploy_dir, dist_dir=None):
|
||||
cmake_flags = cfg.get('cmake_flags', [])
|
||||
|
|
Loading…
Reference in New Issue