307 lines
11 KiB
Python
307 lines
11 KiB
Python
import argparse
|
|
import logging
|
|
import os.path as osp
|
|
import subprocess
|
|
import sys
|
|
import traceback
|
|
from functools import partial
|
|
|
|
import mmcv
|
|
import torch.multiprocessing as mp
|
|
from torch.multiprocessing import Process, set_start_method
|
|
|
|
from mmdeploy.apis import (create_calib_table, extract_model, inference_model,
|
|
torch2onnx)
|
|
from mmdeploy.apis.utils import get_partition_cfg as parse_partition_cfg
|
|
from mmdeploy.utils import (Backend, get_backend, get_calib_filename,
|
|
get_codebase, get_model_inputs, get_onnx_config,
|
|
get_partition_config, load_config)
|
|
from mmdeploy.utils.export_info import dump_info
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the model conversion script.

    Args:
        argv (list[str] | None): Argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(description='Export model to backends.')
    parser.add_argument('deploy_cfg', help='deploy config path')
    parser.add_argument('model_cfg', help='model config path')
    parser.add_argument('checkpoint', help='model checkpoint path')
    parser.add_argument('img', help='image used to convert model')
    parser.add_argument(
        '--test-img', default=None, help='image used to test model')
    # Default to the current directory instead of None: the work dir is used
    # as the base of every output path, and None cannot be joined/abspath'd.
    parser.add_argument(
        '--work-dir',
        default='.',
        help='the dir to save logs and models (default: current directory)')
    parser.add_argument(
        '--calib-dataset-cfg',
        help='dataset config path used to calibrate.',
        default=None)
    parser.add_argument(
        '--device', help='device used for conversion', default='cpu')
    parser.add_argument(
        '--log-level',
        help='set log level',
        default='INFO',
        choices=list(logging._nameToLevel.keys()))
    parser.add_argument(
        '--show', action='store_true', help='Show detection outputs')
    parser.add_argument(
        '--dump-info', action='store_true', help='Output information for SDK')
    return parser.parse_args(argv)
|
|
|
|
|
|
def target_wrapper(target, log_level, ret_value, *args, **kwargs):
    """Run ``target`` (normally inside a child process) and record success.

    Logging is configured here again, presumably because a process started
    with the 'spawn' method does not inherit the parent's logging setup.

    Args:
        target (callable): Function to run.
        log_level (int | str): Level applied to the root logger.
        ret_value: Object with a writable ``value`` attribute (e.g. a
            ``multiprocessing.Value``) or ``None``. It is set to -1 before
            ``target`` runs and to 0 only if ``target`` returns normally.
        *args: Positional arguments forwarded to ``target``.
        **kwargs: Keyword arguments forwarded to ``target``.

    Returns:
        Whatever ``target`` returns, or ``None`` if it raised an
        ``Exception``; in that case the error is logged and the traceback
        is printed to stdout.
    """
    logging.basicConfig(
        format='%(asctime)s,%(msecs)d %(levelname)-8s'
        ' [%(filename)s:%(lineno)d] %(message)s',
        datefmt='%Y-%m-%d:%H:%M:%S')
    logger = logging.getLogger()
    logger.setLevel(log_level)

    # Mark failure up front so that any abnormal exit below leaves -1
    # behind for the parent process to inspect.
    if ret_value is not None:
        ret_value.value = -1

    try:
        result = target(*args, **kwargs)
    except Exception as e:
        logger.error(e)
        traceback.print_exc(file=sys.stdout)
        return None

    if ret_value is not None:
        ret_value.value = 0
    return result
|
|
|
|
|
|
def create_process(name, target, args, kwargs, ret_value=None):
    """Run ``target(*args, **kwargs)`` in a child process and wait for it.

    The child runs through ``target_wrapper`` with the parent's current root
    log level. When ``ret_value`` is given, the child reports its status
    through it. A non-zero value is treated as failure and terminates the
    whole script with exit status 1.

    Args:
        name (str): Human-readable step name used in log messages.
        target (callable): Function executed in the child process.
        args (tuple): Positional arguments for ``target``.
        kwargs (dict): Keyword arguments for ``target``.
        ret_value: Shared value (e.g. ``multiprocessing.Value``) written by
            the child, or ``None`` to skip the success check.
    """
    logging.info(f'{name} start.')
    log_level = logging.getLogger().level

    wrap_func = partial(target_wrapper, target, log_level, ret_value)

    process = Process(target=wrap_func, args=args, kwargs=kwargs)
    process.start()
    process.join()

    if ret_value is None:
        return

    if ret_value.value != 0:
        logging.error(f'{name} failed.')
        # The site builtin ``exit()`` exits with status 0 (and is missing
        # under ``python -S``); report the failure to the caller instead.
        sys.exit(1)

    logging.info(f'{name} success.')
|
|
|
|
|
|
def main():
    """Entry point: convert a PyTorch model to the backend in the deploy config.

    Steps (most of them run in a child process through ``create_process``):
    torch -> ONNX, optional ONNX partitioning, optional calibration table,
    ONNX -> backend model (TensorRT / ncnn / OpenVINO / PPL; any other
    backend keeps the ONNX files), then visualization with the backend model
    and with the original PyTorch checkpoint. Outputs are written under
    ``args.work_dir``; the script stops at the first step reporting failure.
    """
    args = parse_args()
    set_start_method('spawn')
    logging.basicConfig(
        format='%(asctime)s,%(msecs)d %(levelname)-8s'
        ' [%(filename)s:%(lineno)d] %(message)s',
        datefmt='%Y-%m-%d:%H:%M:%S')
    logger = logging.getLogger()
    logger.setLevel(args.log_level)

    deploy_cfg_path = args.deploy_cfg
    model_cfg_path = args.model_cfg
    checkpoint_path = args.checkpoint

    # load deploy_cfg and model_cfg
    deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)

    # create work_dir if it does not exist yet
    mmcv.mkdir_or_exist(osp.abspath(args.work_dir))

    if args.dump_info:
        dump_info(deploy_cfg, model_cfg, args.work_dir)

    # Shared status slot that every child step writes into
    # (-1 = failed, 0 = succeeded).
    ret_value = mp.Value('d', 0, lock=False)

    # convert to onnx
    onnx_save_file = get_onnx_config(deploy_cfg)['save_file']
    create_process(
        'torch2onnx',
        target=torch2onnx,
        args=(args.img, args.work_dir, onnx_save_file, deploy_cfg_path,
              model_cfg_path, checkpoint_path),
        kwargs=dict(device=args.device),
        ret_value=ret_value)

    # ONNX files handed to the backend converter
    onnx_files = [osp.join(args.work_dir, onnx_save_file)]

    # partition model
    partition_cfgs = get_partition_config(deploy_cfg)
    if partition_cfgs is not None:
        if 'partition_cfg' in partition_cfgs:
            partition_cfgs = partition_cfgs.get('partition_cfg', None)
        else:
            assert 'type' in partition_cfgs
            partition_cfgs = parse_partition_cfg(
                get_codebase(deploy_cfg), partition_cfgs['type'])

        # The partitions replace the single end-to-end ONNX file.
        origin_onnx_file = onnx_files[0]
        onnx_files = []
        for partition_cfg in partition_cfgs:
            save_file = partition_cfg['save_file']
            save_path = osp.join(args.work_dir, save_file)
            start = partition_cfg['start']
            end = partition_cfg['end']
            dynamic_axes = partition_cfg.get('dynamic_axes', None)

            create_process(
                f'partition model {save_file} with start: {start}, end: {end}',
                extract_model,
                args=(origin_onnx_file, start, end),
                kwargs=dict(dynamic_axes=dynamic_axes, save_file=save_path),
                ret_value=ret_value)

            onnx_files.append(save_path)

    # calib data
    calib_filename = get_calib_filename(deploy_cfg)
    if calib_filename is not None:
        calib_path = osp.join(args.work_dir, calib_filename)

        create_process(
            'calibration',
            create_calib_table,
            args=(calib_path, deploy_cfg_path, model_cfg_path,
                  checkpoint_path),
            kwargs=dict(
                dataset_cfg=args.calib_dataset_cfg,
                dataset_type='val',
                device=args.device),
            ret_value=ret_value)

    # Backends not handled below run directly on the ONNX files.
    backend_files = onnx_files
    # convert backend
    backend = get_backend(deploy_cfg, 'default')
    if backend == Backend.TENSORRT:
        model_params = get_model_inputs(deploy_cfg)
        assert len(model_params) == len(onnx_files)

        from mmdeploy.apis.tensorrt import is_available as trt_is_available
        from mmdeploy.apis.tensorrt import onnx2tensorrt
        assert trt_is_available(), (
            'TensorRT is not available,'
            ' please install TensorRT and build TensorRT custom ops first.')
        backend_files = []
        for model_id, (model_param, onnx_path) in enumerate(
                zip(model_params, onnx_files)):
            onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
            save_file = model_param.get('save_file', onnx_name + '.engine')

            partition_type = ('end2end'
                              if partition_cfgs is None else onnx_name)
            create_process(
                f'onnx2tensorrt of {onnx_path}',
                target=onnx2tensorrt,
                args=(args.work_dir, save_file, model_id, deploy_cfg_path,
                      onnx_path),
                kwargs=dict(device=args.device, partition_type=partition_type),
                ret_value=ret_value)

            backend_files.append(osp.join(args.work_dir, save_file))

    elif backend == Backend.NCNN:
        from mmdeploy.apis.ncnn import get_onnx2ncnn_path
        from mmdeploy.apis.ncnn import is_available as is_available_ncnn

        if not is_available_ncnn():
            logging.error('ncnn support is not available.')
            sys.exit(-1)

        onnx2ncnn_path = get_onnx2ncnn_path()

        backend_files = []
        for onnx_path in onnx_files:
            onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
            save_param = osp.join(args.work_dir, onnx_name + '.param')
            save_bin = osp.join(args.work_dir, onnx_name + '.bin')

            # onnx2ncnn is run as an external executable, so it cannot report
            # through ret_value; check its exit status explicitly instead of
            # continuing with missing or partial output files.
            return_code = subprocess.call(
                [onnx2ncnn_path, onnx_path, save_param, save_bin])
            if return_code != 0:
                logging.error(f'onnx2ncnn of {onnx_path} failed '
                              f'with exit code {return_code}.')
                sys.exit(1)

            backend_files += [save_param, save_bin]

    elif backend == Backend.OPENVINO:
        from mmdeploy.apis.openvino import \
            is_available as is_available_openvino
        assert is_available_openvino(), \
            'OpenVINO is not available, please install OpenVINO first.'

        from mmdeploy.apis.openvino import (onnx2openvino,
                                            get_output_model_file,
                                            get_input_shape_from_cfg)
        openvino_files = []
        for onnx_path in onnx_files:
            model_xml_path = get_output_model_file(onnx_path, args.work_dir)
            input_name = deploy_cfg.onnx_config.input_names
            input_shape = [get_input_shape_from_cfg(model_cfg)]
            input_info = dict(zip(input_name, input_shape))
            output_names = deploy_cfg.onnx_config.output_names
            create_process(
                f'onnx2openvino with {onnx_path}',
                target=onnx2openvino,
                args=(input_info, output_names, onnx_path, args.work_dir),
                kwargs=dict(),
                ret_value=ret_value)
            openvino_files.append(model_xml_path)
        backend_files = openvino_files

    elif backend == Backend.PPL:
        from mmdeploy.apis.ppl import \
            is_available as is_available_ppl
        assert is_available_ppl(), \
            'PPL is not available, please install PPL first.'

        from mmdeploy.apis.ppl import onnx2ppl
        ppl_files = []
        for onnx_path in onnx_files:
            # Swap only the extension: str.replace would also rewrite any
            # '.onnx' inside directory names, and would return the ONNX path
            # itself (to be overwritten) when the extension differs.
            algo_file = osp.splitext(onnx_path)[0] + '.json'
            model_inputs = get_model_inputs(deploy_cfg)
            # Keep the message in one expression; previously its second half
            # was a separate, discarded string statement.
            assert 'opt_shape' in model_inputs, (
                'expect opt_shape in deploy config for ppl')
            # PPL accepts only 1 input shape for optimization,
            # may get changed in the future
            input_shapes = [model_inputs.opt_shape]
            create_process(
                f'onnx2ppl with {onnx_path}',
                target=onnx2ppl,
                args=(algo_file, onnx_path),
                kwargs=dict(device=args.device, input_shapes=input_shapes),
                ret_value=ret_value)
            ppl_files += [onnx_path, algo_file]
        backend_files = ppl_files

    if args.test_img is None:
        args.test_img = args.img
    # visualize model of the backend
    create_process(
        f'visualize {backend.value} model',
        target=inference_model,
        args=(model_cfg_path, deploy_cfg_path, backend_files, args.test_img,
              args.device),
        kwargs=dict(
            backend=backend,
            output_file=osp.join(args.work_dir, f'output_{backend.value}.jpg'),
            show_result=args.show),
        ret_value=ret_value)

    # visualize pytorch model
    create_process(
        'visualize pytorch model',
        target=inference_model,
        args=(model_cfg_path, deploy_cfg_path, [checkpoint_path],
              args.test_img, args.device),
        kwargs=dict(
            backend=Backend.PYTORCH,
            output_file=osp.join(args.work_dir, 'output_pytorch.jpg'),
            show_result=args.show),
        ret_value=ret_value)

    logging.info('All process success.')
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the conversion pipeline only when executed as a script,
    # never as a side effect of importing this module.
    main()
|