mmdeploy/tools/deploy.py
AllentDan d0bb497b0c
[Fix] fix ncnn apis, add onnx2ncnn unit tests, and rename test_mmseg/test_utils.py to avoid conflicts (#233)
* refactor ncnn convert

* refine docstring

* add onnx2ncnn test

* resolve comments and rename test_utils.py in test_mmseg

Co-authored-by: maningsheng <mnsheng@yeah.net>
2021-11-26 11:43:27 +08:00

304 lines
11 KiB
Python

import argparse
import logging
import os.path as osp
import sys
import traceback
from functools import partial
import mmcv
import torch.multiprocessing as mp
from torch.multiprocessing import Process, set_start_method
from mmdeploy.apis import (create_calib_table, extract_model,
get_predefined_partition_cfg, torch2onnx,
visualize_model)
from mmdeploy.utils import (Backend, get_backend, get_calib_filename,
get_model_inputs, get_onnx_config,
get_partition_config, load_config)
from mmdeploy.utils.export_info import dump_info
def parse_args():
    """Parse the command-line arguments of the deployment script.

    Positional arguments are the deploy config path, the model config path,
    the checkpoint path and the image used for conversion. Everything else
    is optional.

    Returns:
        argparse.Namespace: The parsed arguments, with attributes
            ``deploy_cfg``, ``model_cfg``, ``checkpoint``, ``img``,
            ``test_img``, ``work_dir``, ``calib_dataset_cfg``, ``device``,
            ``log_level``, ``show`` and ``dump_info``.
    """
    parser = argparse.ArgumentParser(description='Export model to backends.')
    parser.add_argument('deploy_cfg', help='deploy config path')
    parser.add_argument('model_cfg', help='model config path')
    parser.add_argument('checkpoint', help='model checkpoint path')
    parser.add_argument('img', help='image used to convert model')
    parser.add_argument(
        '--test-img', default=None, help='image used to test model')
    # Default to the current directory: main() hands this value straight to
    # osp.abspath(), which raises TypeError when it receives None.
    parser.add_argument(
        '--work-dir',
        default=osp.abspath(osp.curdir),
        help='the dir to save logs and models')
    parser.add_argument(
        '--calib-dataset-cfg',
        help='dataset config path used to calibrate.',
        default=None)
    parser.add_argument(
        '--device', help='device used for conversion', default='cpu')
    # The accepted names are taken from the logging module's own
    # name-to-level table.
    parser.add_argument(
        '--log-level',
        help='set log level',
        default='INFO',
        choices=list(logging._nameToLevel.keys()))
    parser.add_argument(
        '--show', action='store_true', help='Show detection outputs')
    parser.add_argument(
        '--dump-info', action='store_true', help='Output information for SDK')
    return parser.parse_args()
def target_wrapper(target, log_level, ret_value, *args, **kwargs):
    """Call ``target`` with logging configured and record success or failure.

    This is the function actually executed inside the child process started
    by ``create_process``. It sets up the root logger first, since this code
    may run in a fresh interpreter.

    Args:
        target (callable): The function to execute.
        log_level (int | str): Level applied to the root logger.
        ret_value: Object with a writable ``value`` attribute, or None. When
            given, it is set to -1 before ``target`` runs and to 0 once
            ``target`` has returned without raising.
        *args: Positional arguments forwarded to ``target``.
        **kwargs: Keyword arguments forwarded to ``target``.

    Returns:
        The return value of ``target``, or None if ``target`` raised an
        ``Exception`` (the error is logged and the traceback is printed to
        stdout instead of being propagated).
    """
    logging.basicConfig(
        format='%(asctime)s,%(name)s %(levelname)-8s'
        ' [%(filename)s:%(lineno)d] %(message)s',
        datefmt='%Y-%m-%d:%H:%M:%S')
    logging.getLogger().setLevel(log_level)

    # Assume failure until the target has finished, so that a crash in the
    # middle of the call is still visible to whoever reads ret_value.
    if ret_value is not None:
        ret_value.value = -1

    try:
        result = target(*args, **kwargs)
    except Exception as e:
        logging.error(e)
        traceback.print_exc(file=sys.stdout)
        return None

    if ret_value is not None:
        ret_value.value = 0
    return result
def create_process(name, target, args, kwargs, ret_value=None):
    """Run ``target`` in a child process and wait for it to finish.

    The target is wrapped with ``target_wrapper`` so that the child uses the
    same root log level as the parent and reports its outcome through
    ``ret_value``.

    Args:
        name (str): Human-readable step name used in log messages.
        target (callable): The function to run in the child process.
        args (tuple): Positional arguments for ``target``.
        kwargs (dict): Keyword arguments for ``target``.
        ret_value: Shared object with a ``value`` attribute that the child
            sets to 0 on success, or None to skip the success check.

    Raises:
        SystemExit: With exit status 1 when ``ret_value`` is given and the
            child did not report success.
    """
    logging.info(f'{name} start.')
    log_level = logging.getLogger().level

    if ret_value is not None:
        # Mark the step as failed up front. If the child dies before
        # target_wrapper gets to run, a 0 left over from a previous step
        # must not be mistaken for success.
        ret_value.value = -1

    wrap_func = partial(target_wrapper, target, log_level, ret_value)
    process = Process(target=wrap_func, args=args, kwargs=kwargs)
    process.start()
    process.join()

    if ret_value is None:
        return
    if ret_value.value != 0:
        logging.error(f'{name} failed.')
        # Exit with a non-zero status so that shells and CI see the failure;
        # a bare exit() would terminate with status 0.
        sys.exit(1)
    logging.info(f'{name} success.')
def main():
    """Convert a PyTorch model to the backend selected by the deploy config.

    The pipeline is: export the model to ONNX, optionally partition the ONNX
    file, optionally create a calibration table, convert the ONNX file(s) to
    the configured backend (TensorRT, ncnn, OpenVINO or PPL), and finally run
    ``visualize_model`` on a test image for both the backend model and the
    original PyTorch model. Every step runs in its own child process through
    ``create_process``, which terminates the script if the step fails.
    """
    args = parse_args()
    set_start_method('spawn')

    logging.basicConfig(
        format='%(asctime)s,%(name)s %(levelname)-8s'
        ' [%(filename)s:%(lineno)d] %(message)s',
        datefmt='%Y-%m-%d:%H:%M:%S')
    logger = logging.getLogger()
    logger.setLevel(args.log_level)

    deploy_cfg_path = args.deploy_cfg
    model_cfg_path = args.model_cfg
    checkpoint_path = args.checkpoint

    # load deploy_cfg
    deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)

    # create work_dir if not
    mmcv.mkdir_or_exist(osp.abspath(args.work_dir))

    if args.dump_info:
        dump_info(deploy_cfg, model_cfg, args.work_dir)

    # Shared status flag that every child process reports its outcome through.
    ret_value = mp.Value('d', 0, lock=False)

    # convert onnx
    onnx_save_file = get_onnx_config(deploy_cfg)['save_file']
    create_process(
        'torch2onnx',
        target=torch2onnx,
        args=(args.img, args.work_dir, onnx_save_file, deploy_cfg_path,
              model_cfg_path, checkpoint_path),
        kwargs=dict(device=args.device),
        ret_value=ret_value)

    # ONNX files handed to the backend conversion below.
    onnx_files = [osp.join(args.work_dir, onnx_save_file)]

    # partition model
    partition_cfgs = get_partition_config(deploy_cfg)

    if partition_cfgs is not None:
        # Either an explicit list of partitions is given, or a predefined
        # partition scheme is selected by its type.
        if 'partition_cfg' in partition_cfgs:
            partition_cfgs = partition_cfgs.get('partition_cfg', None)
        else:
            assert 'type' in partition_cfgs
            partition_cfgs = get_predefined_partition_cfg(
                deploy_cfg, partition_cfgs['type'])

        # The partitions replace the end-to-end ONNX file in onnx_files.
        origin_onnx_file = onnx_files[0]
        onnx_files = []
        for partition_cfg in partition_cfgs:
            save_file = partition_cfg['save_file']
            save_path = osp.join(args.work_dir, save_file)
            start = partition_cfg['start']
            end = partition_cfg['end']
            dynamic_axes = partition_cfg.get('dynamic_axes', None)

            create_process(
                f'partition model {save_file} with start: {start}, end: {end}',
                extract_model,
                args=(origin_onnx_file, start, end),
                kwargs=dict(dynamic_axes=dynamic_axes, save_file=save_path),
                ret_value=ret_value)

            onnx_files.append(save_path)

    # calib data
    calib_filename = get_calib_filename(deploy_cfg)
    if calib_filename is not None:
        calib_path = osp.join(args.work_dir, calib_filename)

        create_process(
            'calibration',
            create_calib_table,
            args=(calib_path, deploy_cfg_path, model_cfg_path,
                  checkpoint_path),
            kwargs=dict(
                dataset_cfg=args.calib_dataset_cfg,
                dataset_type='val',
                device=args.device),
            ret_value=ret_value)

    # Backends without a branch below are visualized with the ONNX files.
    backend_files = onnx_files

    # convert backend
    backend = get_backend(deploy_cfg, 'default')
    if backend == Backend.TENSORRT:
        model_params = get_model_inputs(deploy_cfg)
        assert len(model_params) == len(onnx_files)

        from mmdeploy.apis.tensorrt import is_available as trt_is_available
        from mmdeploy.apis.tensorrt import onnx2tensorrt
        assert trt_is_available(
        ), 'TensorRT is not available,' \
            + ' please install TensorRT and build TensorRT custom ops first.'

        backend_files = []
        for model_id, (model_param, onnx_path) in enumerate(
                zip(model_params, onnx_files)):
            onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
            save_file = model_param.get('save_file', onnx_name + '.engine')

            # Without partitioning there is a single end-to-end model;
            # otherwise each partition is identified by its ONNX file name.
            partition_type = 'end2end' if partition_cfgs is None \
                else onnx_name
            create_process(
                f'onnx2tensorrt of {onnx_path}',
                target=onnx2tensorrt,
                args=(args.work_dir, save_file, model_id, deploy_cfg_path,
                      onnx_path),
                kwargs=dict(device=args.device, partition_type=partition_type),
                ret_value=ret_value)

            backend_files.append(osp.join(args.work_dir, save_file))

    elif backend == Backend.NCNN:
        from mmdeploy.apis.ncnn import is_available as is_available_ncnn

        if not is_available_ncnn():
            logging.error('ncnn support is not available.')
            sys.exit(-1)

        from mmdeploy.apis.ncnn import onnx2ncnn, get_output_model_file

        backend_files = []
        for onnx_path in onnx_files:
            model_param_path, model_bin_path = get_output_model_file(
                onnx_path, args.work_dir)
            create_process(
                f'onnx2ncnn with {onnx_path}',
                target=onnx2ncnn,
                args=(onnx_path, model_param_path, model_bin_path),
                kwargs=dict(),
                ret_value=ret_value)
            backend_files += [model_param_path, model_bin_path]

    elif backend == Backend.OPENVINO:
        from mmdeploy.apis.openvino import \
            is_available as is_available_openvino
        assert is_available_openvino(), \
            'OpenVINO is not available, please install OpenVINO first.'

        from mmdeploy.apis.openvino import (onnx2openvino,
                                            get_output_model_file,
                                            get_input_shape_from_cfg)
        openvino_files = []
        for onnx_path in onnx_files:
            model_xml_path = get_output_model_file(onnx_path, args.work_dir)
            input_name = deploy_cfg.onnx_config.input_names
            input_shape = [get_input_shape_from_cfg(model_cfg)]
            input_info = dict(zip(input_name, input_shape))
            output_names = deploy_cfg.onnx_config.output_names
            create_process(
                f'onnx2openvino with {onnx_path}',
                target=onnx2openvino,
                args=(input_info, output_names, onnx_path, args.work_dir),
                kwargs=dict(),
                ret_value=ret_value)
            openvino_files.append(model_xml_path)
        backend_files = openvino_files

    elif backend == Backend.PPL:
        from mmdeploy.apis.ppl import \
            is_available as is_available_ppl
        assert is_available_ppl(), \
            'PPL is not available, please install PPL first.'

        from mmdeploy.apis.ppl import onnx2ppl
        ppl_files = []
        for onnx_path in onnx_files:
            algo_file = onnx_path.replace('.onnx', '.json')
            model_inputs = get_model_inputs(deploy_cfg)
            # The message must be a single expression; a string literal on
            # its own following line would be a separate no-op statement.
            assert 'opt_shape' in model_inputs, (
                'expect opt_shape in deploy config for ppl')
            # PPL accepts only 1 input shape for optimization,
            # may get changed in the future
            input_shapes = [model_inputs.opt_shape]
            create_process(
                f'onnx2ppl with {onnx_path}',
                target=onnx2ppl,
                args=(algo_file, onnx_path),
                kwargs=dict(device=args.device, input_shapes=input_shapes),
                ret_value=ret_value)
            ppl_files += [onnx_path, algo_file]
        backend_files = ppl_files

    # Fall back to the conversion image when no test image was given.
    if args.test_img is None:
        args.test_img = args.img

    # visualize model of the backend
    create_process(
        f'visualize {backend.value} model',
        target=visualize_model,
        args=(model_cfg_path, deploy_cfg_path, backend_files, args.test_img,
              args.device),
        kwargs=dict(
            backend=backend,
            output_file=osp.join(args.work_dir, f'output_{backend.value}.jpg'),
            show_result=args.show),
        ret_value=ret_value)

    # visualize pytorch model
    create_process(
        'visualize pytorch model',
        target=visualize_model,
        args=(model_cfg_path, deploy_cfg_path, [checkpoint_path],
              args.test_img, args.device),
        kwargs=dict(
            backend=Backend.PYTORCH,
            output_file=osp.join(args.work_dir, 'output_pytorch.jpg'),
            show_result=args.show),
        ret_value=ret_value)

    logging.info('All process success.')
# Run the pipeline only when this file is executed as a script. The guard
# also matters because main() selects the 'spawn' start method, under which
# child processes import the main module again.
if __name__ == '__main__':
    main()