mmdeploy/tools/deploy.py
2021-06-23 13:14:28 +08:00

97 lines
2.9 KiB
Python

import argparse
import logging
import os.path as osp
import sys

import mmcv
import torch.multiprocessing as mp
from torch.multiprocessing import Process, set_start_method

from mmdeploy.apis import torch2onnx
def parse_args():
    """Define the deploy command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed options with the attributes
        ``deploy_cfg``, ``model_cfg``, ``checkpoint``, ``img``,
        ``work_dir`` (None when omitted), ``device`` and ``log_level``.
    """
    parser = argparse.ArgumentParser(description='Export model to backend.')

    # Positional arguments, registered in the order they must be supplied.
    positional_specs = (
        ('deploy_cfg', 'deploy config path'),
        ('model_cfg', 'model config path'),
        ('checkpoint', 'model checkpoint path'),
        ('img', 'image used to convert model and test model'),
    )
    for dest, help_text in positional_specs:
        parser.add_argument(dest, help=help_text)

    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--device', default='cpu', help='device used for conversion')

    # Accept every level name the logging module knows about.
    level_names = list(logging._nameToLevel)
    parser.add_argument(
        '--log-level',
        default='INFO',
        choices=level_names,
        help='set log level')

    return parser.parse_args()
def _run_stage(stage_name, target, args, ret_value, **kwargs):
    """Run one conversion stage in a child process; exit(1) if it fails.

    Args:
        stage_name (str): Name used in the '<name> failed.' /
            '<name> success.' log messages.
        target (callable): Function executed in the child process.
        args (tuple): Positional arguments forwarded to ``target``.
        ret_value: Shared value created by ``mp.Value``. It is forwarded to
            ``target`` as the ``ret_value`` keyword. A non-zero value after
            the child finishes is treated as failure.
        **kwargs: Extra keyword arguments forwarded to ``target``.

    Raises:
        SystemExit: With status 1 when the stage fails.
    """
    process = Process(
        target=target, args=args, kwargs=dict(kwargs, ret_value=ret_value))
    process.start()
    process.join()
    # A child can die (uncaught exception, signal) before it ever writes
    # ``ret_value``. The value then stays 0, so also check the exit code to
    # avoid reporting such a crash as success.
    if ret_value.value != 0 or process.exitcode != 0:
        logging.error(f'{stage_name} failed.')
        # Exit non-zero so shell scripts / CI can detect the failure. A bare
        # ``exit()`` would terminate with status 0.
        sys.exit(1)
    logging.info(f'{stage_name} success.')


def main():
    """Convert a checkpoint to ONNX and, for the tensorrt backend, TensorRT.

    Reads the deploy config and writes artifacts into ``--work-dir`` (the
    current directory when the option is omitted). Each conversion stage runs
    in its own child process. The script exits with status 1 as soon as a
    stage fails.
    """
    args = parse_args()
    # NOTE(review): 'spawn' is presumably chosen because CUDA cannot be
    # re-initialised in forked child processes; confirm.
    set_start_method('spawn')

    # Configure the root logger, which the module-level ``logging.*`` calls
    # below use.
    logging.getLogger().setLevel(args.log_level)

    deploy_cfg_path = args.deploy_cfg
    model_cfg_path = args.model_cfg
    checkpoint_path = args.checkpoint

    # load deploy_cfg
    deploy_cfg = mmcv.Config.fromfile(deploy_cfg_path)
    if not isinstance(deploy_cfg, mmcv.Config):
        raise TypeError('deploy_cfg must be a filename or Config object, '
                        f'but got {type(deploy_cfg)}')

    # ``--work-dir`` has no default. Fall back to the current directory
    # instead of crashing in ``osp.abspath(None)``.
    work_dir = args.work_dir if args.work_dir is not None else '.'
    mmcv.mkdir_or_exist(osp.abspath(work_dir))

    # Shared double (no lock) through which each stage reports its status.
    # Any non-zero value is treated as failure by ``_run_stage``.
    ret_value = mp.Value('d', 0, lock=False)

    # convert onnx
    logging.info('start torch2onnx conversion.')
    onnx_save_file = deploy_cfg['pytorch2onnx']['save_file']
    _run_stage(
        'torch2onnx',
        torch2onnx,
        (args.img, work_dir, onnx_save_file, deploy_cfg_path, model_cfg_path,
         checkpoint_path),
        ret_value,
        device=args.device)

    # convert backend
    backend = deploy_cfg.get('backend', 'default')
    onnx_paths = [osp.join(work_dir, onnx_save_file)]
    if backend == 'tensorrt':
        logging.info('start onnx2tensorrt conversion.')
        # Imported here, presumably so TensorRT is only required when this
        # backend is selected.
        from mmdeploy.apis.tensorrt import onnx2tensorrt
        for onnx_path in onnx_paths:
            _run_stage(
                'onnx2tensorrt',
                onnx2tensorrt,
                (work_dir, onnx_save_file, deploy_cfg_path, onnx_path),
                ret_value,
                device=args.device)

    logging.info('All process success.')
if __name__ == '__main__':
    # The guard is essential with the 'spawn' start method. Child processes
    # re-import this module and must not re-run the whole pipeline.
    main()