2021-06-17 15:28:23 +08:00
|
|
|
import argparse
|
|
|
|
import logging
|
2021-06-17 15:29:12 +08:00
|
|
|
import os.path as osp
|
2021-06-17 15:28:23 +08:00
|
|
|
|
|
|
|
import mmcv
|
2021-06-17 17:26:32 +08:00
|
|
|
import torch.multiprocessing as mp
|
2021-06-17 15:29:12 +08:00
|
|
|
from torch.multiprocessing import Process, set_start_method
|
2021-06-17 15:28:23 +08:00
|
|
|
|
|
|
|
from mmdeploy.apis import torch2onnx
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args():
    """Build the deploy tool's command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: attributes ``deploy_cfg``, ``model_cfg``,
        ``checkpoint`` and ``img`` (required positionals), ``work_dir``
        (``None`` when ``--work-dir`` is omitted), ``device`` (default
        ``'cpu'``) and ``log_level`` (default ``'INFO'``).
    """
    cli = argparse.ArgumentParser(description='Export model to backend.')

    # Required positional inputs, registered in the order they must appear
    # on the command line.
    positionals = (
        ('deploy_cfg', 'deploy config path'),
        ('model_cfg', 'model config path'),
        ('checkpoint', 'model checkpoint path'),
        ('img', 'image used to convert model and test model'),
    )
    for dest, description in positionals:
        cli.add_argument(dest, help=description)

    cli.add_argument('--work-dir', help='the dir to save logs and models')
    cli.add_argument(
        '--device', default='cpu', help='device used for conversion')

    # Accept every level name known to the logging module (this reads the
    # private ``logging._nameToLevel`` table).
    level_names = list(logging._nameToLevel)
    cli.add_argument(
        '--log-level',
        default='INFO',
        choices=level_names,
        help='set log level')

    return cli.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
def _run_in_subprocess(target, args, kwargs, ret_value, task_name):
    """Run one conversion stage in a child process and report its outcome.

    Args:
        target (callable): Function executed in the child process.
        args (tuple): Positional arguments forwarded to ``target``.
        kwargs (dict): Keyword arguments forwarded to ``target``.
        ret_value: Shared value object; a non-zero ``ret_value.value`` after
            the child finishes is treated as failure.
        task_name (str): Stage name used in the log messages.

    Returns:
        bool: ``True`` when the child exited with code 0 and
        ``ret_value.value`` is 0, ``False`` otherwise.
    """
    process = Process(target=target, args=args, kwargs=kwargs)
    process.start()
    process.join()
    # A child that dies abruptly (e.g. segfault or OOM kill) never gets to
    # write ``ret_value``, so its exit code has to be checked as well.
    if process.exitcode != 0 or ret_value.value != 0:
        logging.error(f'{task_name} failed.')
        return False
    logging.info(f'{task_name} success.')
    return True


def main():
    """Convert a PyTorch model to ONNX and, optionally, to a TensorRT engine.

    Each conversion stage runs in its own child process (start method
    ``'spawn'``). The script exits with status 1 as soon as a stage fails.
    """
    args = parse_args()
    set_start_method('spawn')

    logger = logging.getLogger()
    logger.setLevel(args.log_level)

    deploy_cfg_path = args.deploy_cfg
    model_cfg_path = args.model_cfg
    checkpoint_path = args.checkpoint
    # ``--work-dir`` has no CLI default; fall back to the current directory
    # instead of crashing in ``osp.abspath(None)``.
    work_dir = args.work_dir if args.work_dir is not None else osp.curdir

    # load deploy_cfg
    deploy_cfg = mmcv.Config.fromfile(deploy_cfg_path)
    if not isinstance(deploy_cfg, (mmcv.Config, mmcv.ConfigDict)):
        raise TypeError('deploy_cfg must be a filename or Config object, '
                        f'but got {type(deploy_cfg)}')

    # create work_dir if not
    mmcv.mkdir_or_exist(osp.abspath(work_dir))

    # Shared slot handed to each conversion function through ``ret_value``.
    ret_value = mp.Value('d', 0, lock=False)
    stage_kwargs = dict(device=args.device, ret_value=ret_value)

    # convert onnx
    logging.info('start torch2onnx conversion.')
    onnx_save_file = deploy_cfg['pytorch2onnx']['save_file']
    onnx_ok = _run_in_subprocess(
        torch2onnx,
        (args.img, work_dir, onnx_save_file, deploy_cfg_path, model_cfg_path,
         checkpoint_path),
        stage_kwargs,
        ret_value,
        'torch2onnx')
    if not onnx_ok:
        # ``exit()`` would report success (status 0) to the calling shell.
        raise SystemExit(1)

    # convert backend
    onnx_paths = [osp.join(work_dir, onnx_save_file)]

    backend = deploy_cfg.get('backend', 'default')
    if backend == 'tensorrt':
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if not hasattr(deploy_cfg, 'tensorrt_params'):
            raise ValueError(
                'deploy_cfg must define `tensorrt_params` when backend is '
                "'tensorrt'")
        tensorrt_params = deploy_cfg['tensorrt_params']
        model_params = tensorrt_params.get('model_params', [])
        if len(model_params) != len(onnx_paths):
            raise ValueError(
                f'tensorrt_params.model_params has {len(model_params)} '
                f'entries, expected {len(onnx_paths)} (one per ONNX file)')

        logging.info('start onnx2tensorrt conversion.')
        from mmdeploy.apis.tensorrt import onnx2tensorrt
        for model_id, (model_param, onnx_path) in enumerate(
                zip(model_params, onnx_paths)):
            onnx_name = osp.splitext(osp.basename(onnx_path))[0]
            save_file = model_param.get('save_file', onnx_name + '.engine')
            trt_ok = _run_in_subprocess(
                onnx2tensorrt,
                (work_dir, save_file, model_id, deploy_cfg_path, onnx_path),
                stage_kwargs,
                ret_value,
                'onnx2tensorrt')
            if not trt_ok:
                raise SystemExit(1)

    logging.info('All process success.')
|
|
|
|
|
2021-06-17 15:28:23 +08:00
|
|
|
|
2021-06-17 15:29:12 +08:00
|
|
|
if __name__ == '__main__':
    # Run the conversion pipeline only when executed as a script, not when
    # this module is imported.
    main()
|