# Copyright (c) OpenMMLab. All rights reserved.
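# Example invocation (added note; the config and checkpoint paths below are
# illustrative, substitute your own):
#   python tools/deploy.py \
#       configs/mmdet/detection/detection_onnxruntime_dynamic.py \
#       /path/to/model_config.py /path/to/checkpoint.pth demo.jpg \
#       --work-dir work_dir --device cpu --dump-info
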
import argparse
import logging
import os
import os.path as osp
from functools import partial

import mmengine
import torch.multiprocessing as mp
from torch.multiprocessing import Process, set_start_method

from mmdeploy.apis import (create_calib_input_data, extract_model,
                           get_predefined_partition_cfg, torch2onnx,
                           torch2torchscript, visualize_model)
from mmdeploy.apis.core import PIPELINE_MANAGER
from mmdeploy.apis.utils import to_backend
from mmdeploy.backend.sdk.export_info import export2SDK
from mmdeploy.utils import (IR, Backend, get_backend, get_calib_filename,
                            get_ir_config, get_partition_config,
                            get_root_logger, load_config, target_wrapper)


def parse_args():
    parser = argparse.ArgumentParser(description='Export model to backends.')
    parser.add_argument('deploy_cfg', help='deploy config path')
    parser.add_argument('model_cfg', help='model config path')
    parser.add_argument('checkpoint', help='model checkpoint path')
    parser.add_argument('img', help='image used to convert the model')
    parser.add_argument(
        '--test-img',
        default=None,
        type=str,
        nargs='+',
        help='image used to test the model')
    parser.add_argument(
        '--work-dir',
        default=os.getcwd(),
        help='the dir to save logs and models')
    parser.add_argument(
        '--calib-dataset-cfg',
        help='dataset config path used to calibrate in int8 mode. If not '
        'specified, the "val" dataset in the model config is used instead.',
        default=None)
    parser.add_argument(
        '--device', help='device used for conversion', default='cpu')
    parser.add_argument(
        '--log-level',
        help='set log level',
        default='INFO',
        choices=list(logging._nameToLevel.keys()))
    parser.add_argument(
        '--show', action='store_true', help='Show detection outputs')
    parser.add_argument(
        '--dump-info', action='store_true', help='Output information for SDK')
    parser.add_argument(
        '--quant-image-dir',
        default=None,
        help='Image directory used to quantize the model.')
    parser.add_argument(
        '--quant', action='store_true', help='Quantize the model to low bit.')
    parser.add_argument(
        '--uri',
        default='192.168.1.1:60000',
        help='Remote ipv4:port or ipv6:port for inference on edge device.')
    args = parser.parse_args()
    return args


def create_process(name, target, args, kwargs, ret_value=None):
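    """Run ``target`` in a spawned subprocess and wait for it to finish.

    ``target`` is wrapped with ``target_wrapper`` so that a failure inside
    the subprocess is recorded in the shared ``ret_value``; a nonzero value
    logs an error and aborts the whole pipeline via ``exit(1)``.
    """
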
    logger = get_root_logger()
    logger.info(f'{name} start.')
    log_level = logger.level

    wrap_func = partial(target_wrapper, target, log_level, ret_value)

    process = Process(target=wrap_func, args=args, kwargs=kwargs)
    process.start()
    process.join()

    if ret_value is not None:
        if ret_value.value != 0:
            logger.error(f'{name} failed.')
            exit(1)
        else:
            logger.info(f'{name} success.')


def torch2ir(ir_type: IR):
    """Return the conversion function from torch to the intermediate
    representation.

    Args:
        ir_type (IR): The type of the intermediate representation.
    """
    if ir_type == IR.ONNX:
        return torch2onnx
    elif ir_type == IR.TORCHSCRIPT:
        return torch2torchscript
    else:
        raise KeyError(f'Unexpected IR type {ir_type}')


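# Example (illustrative): torch2ir(IR.ONNX) returns torch2onnx, which main()
# below calls with the image, work dir, IR save file, configs and device.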
def main():
    args = parse_args()
    set_start_method('spawn', force=True)
    logger = get_root_logger()
    log_level = logging.getLevelName(args.log_level)
    logger.setLevel(log_level)

    pipeline_funcs = [
        torch2onnx, torch2torchscript, extract_model, create_calib_input_data
    ]
    PIPELINE_MANAGER.enable_multiprocess(True, pipeline_funcs)
    PIPELINE_MANAGER.set_log_level(log_level, pipeline_funcs)

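    # Added note: together with set_start_method('spawn') above, this
    # presumably makes each registered pipeline function execute in its own
    # subprocess rather than in this interpreter.
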
    deploy_cfg_path = args.deploy_cfg
    model_cfg_path = args.model_cfg
    checkpoint_path = args.checkpoint
    quant = args.quant
    quant_image_dir = args.quant_image_dir

    # load deploy_cfg and model_cfg
    deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)

    # create work_dir if it does not exist
    mmengine.mkdir_or_exist(osp.abspath(args.work_dir))

    if args.dump_info:
        export2SDK(
            deploy_cfg,
            model_cfg,
            args.work_dir,
            pth=checkpoint_path,
            device=args.device)

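    # Shared status slot for the subprocesses: target_wrapper records a
    # nonzero value here on failure, which create_process() turns into
    # exit(1).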
    ret_value = mp.Value('d', 0, lock=False)

    # convert to IR
    ir_config = get_ir_config(deploy_cfg)
    ir_save_file = ir_config['save_file']
    ir_type = IR.get(ir_config['type'])
    torch2ir(ir_type)(
        args.img,
        args.work_dir,
        ir_save_file,
        deploy_cfg_path,
        model_cfg_path,
        checkpoint_path,
        device=args.device)

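    # Illustrative shape of the IR block read by get_ir_config above (field
    # values are examples, not fixed):
    #   onnx_config = dict(type='onnx', save_file='end2end.onnx')
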
    # convert backend
    ir_files = [osp.join(args.work_dir, ir_save_file)]

    # partition model
    partition_cfgs = get_partition_config(deploy_cfg)

    if partition_cfgs is not None:
        if 'partition_cfg' in partition_cfgs:
            partition_cfgs = partition_cfgs.get('partition_cfg', None)
        else:
            assert 'type' in partition_cfgs
            partition_cfgs = get_predefined_partition_cfg(
                deploy_cfg, partition_cfgs['type'])

        origin_ir_file = ir_files[0]
        ir_files = []
        for partition_cfg in partition_cfgs:
            save_file = partition_cfg['save_file']
            save_path = osp.join(args.work_dir, save_file)
            start = partition_cfg['start']
            end = partition_cfg['end']
            dynamic_axes = partition_cfg.get('dynamic_axes', None)

            extract_model(
                origin_ir_file,
                start,
                end,
                dynamic_axes=dynamic_axes,
                save_file=save_path)

            ir_files.append(save_path)

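    # Hedged sketch of a partition config consumed above (keys mirror the
    # get_partition_config / extract_model usage; marker names elided):
    #   partition_config = dict(
    #       type='...',  # predefined scheme name, or give partition_cfg
    #       partition_cfg=[
    #           dict(save_file='part0.onnx', start=[...], end=[...])])
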
    # calib data
    calib_filename = get_calib_filename(deploy_cfg)
    if calib_filename is not None:
        calib_path = osp.join(args.work_dir, calib_filename)
        create_calib_input_data(
            calib_path,
            deploy_cfg_path,
            model_cfg_path,
            checkpoint_path,
            dataset_cfg=args.calib_dataset_cfg,
            dataset_type='val',
            device=args.device)

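    # get_calib_filename returns None unless the deploy config enables int8
    # calibration; the generated file is consumed during backend conversion
    # (e.g. TensorRT engine building).
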
    backend_files = ir_files
    # convert backend
    backend = get_backend(deploy_cfg)

    # preprocess deploy_cfg
    if backend == Backend.RKNN:
        # TODO: Add this to task_processor in the future
        import tempfile

        from mmdeploy.utils import (get_common_config, get_normalization,
                                    get_quantization_config,
                                    get_rknn_quantization)
        quantization_cfg = get_quantization_config(deploy_cfg)
        common_params = get_common_config(deploy_cfg)
        if get_rknn_quantization(deploy_cfg) is True:
            transform = get_normalization(model_cfg)
            common_params.update(
                dict(
                    mean_values=[transform['mean']],
                    std_values=[transform['std']]))

        dataset_file = tempfile.NamedTemporaryFile(suffix='.txt').name
        with open(dataset_file, 'w') as f:
            f.writelines([osp.abspath(args.img)])
        if quantization_cfg.get('dataset', None) is None:
            quantization_cfg['dataset'] = dataset_file
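
    # Note: when no quantization dataset is configured, RKNN falls back to
    # the one-line dataset file written above (just the conversion image).
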
    if backend == Backend.ASCEND:
        # TODO: Add this to backend manager in the future
        if args.dump_info:
            from mmdeploy.backend.ascend import update_sdk_pipeline
            update_sdk_pipeline(args.work_dir)

    if backend == Backend.VACC:
        # TODO: Add this to task_processor in the future

        from onnx2vacc_quant_dataset import get_quant

        from mmdeploy.utils import get_model_inputs

        deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)
        model_inputs = get_model_inputs(deploy_cfg)

        for onnx_path, model_input in zip(ir_files, model_inputs):
            quant_mode = model_input.get('qconfig', {}).get('dtype', 'fp16')
            assert quant_mode in ['int8', 'fp16'], \
                f'{quant_mode} not supported now'
            shape_dict = model_input.get('shape', {})

            if quant_mode == 'int8':
                create_process(
                    'vacc quant dataset',
                    target=get_quant,
                    args=(deploy_cfg, model_cfg, shape_dict, checkpoint_path,
                          args.work_dir, args.device),
                    kwargs=dict(),
                    ret_value=ret_value)

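    # Each VACC model input selects int8 or fp16 through its qconfig; int8
    # additionally needs the quantization dataset generated just above.
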
    # convert to backend
    PIPELINE_MANAGER.set_log_level(log_level, [to_backend])
    if backend == Backend.TENSORRT:
        PIPELINE_MANAGER.enable_multiprocess(True, [to_backend])
    backend_files = to_backend(
        backend,
        ir_files,
        work_dir=args.work_dir,
        deploy_cfg=deploy_cfg,
        log_level=log_level,
        device=args.device,
        uri=args.uri)

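    # to_backend returns the final backend artifacts, e.g. a serialized
    # TensorRT engine or ncnn .param/.bin pairs; these replace the IR files
    # for the inference runs below.
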
    # ncnn quantization
    if backend == Backend.NCNN and quant:
        from onnx2ncnn_quant_table import get_table

        from mmdeploy.apis.ncnn import get_quant_model_file, ncnn2int8
        model_param_paths = backend_files[::2]
        model_bin_paths = backend_files[1::2]
        backend_files = []
        for onnx_path, model_param_path, model_bin_path in zip(
                ir_files, model_param_paths, model_bin_paths):

            deploy_cfg, model_cfg = load_config(deploy_cfg_path,
                                                model_cfg_path)
            quant_onnx, quant_table, quant_param, quant_bin = get_quant_model_file(  # noqa: E501
                onnx_path, args.work_dir)

            create_process(
                'ncnn quant table',
                target=get_table,
                args=(onnx_path, deploy_cfg, model_cfg, quant_onnx,
                      quant_table, quant_image_dir, args.device),
                kwargs=dict(),
                ret_value=ret_value)

            create_process(
                'ncnn_int8',
                target=ncnn2int8,
                args=(model_param_path, model_bin_path, quant_table,
                      quant_param, quant_bin),
                kwargs=dict(),
                ret_value=ret_value)
            backend_files += [quant_param, quant_bin]

    if args.test_img is None:
        args.test_img = args.img

    extra = dict(
        backend=backend,
        output_file=osp.join(args.work_dir, f'output_{backend.value}.jpg'),
        show_result=args.show)
    if backend == Backend.SNPE:
        extra['uri'] = args.uri

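    # The two runs below write output_<backend>.jpg and output_pytorch.jpg
    # into work_dir, giving a side-by-side sanity check of backend vs.
    # PyTorch inference.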
    # get backend inference result, try to render it
    create_process(
        f'visualize {backend.value} model',
        target=visualize_model,
        args=(model_cfg_path, deploy_cfg_path, backend_files, args.test_img,
              args.device),
        kwargs=extra,
        ret_value=ret_value)

    # get pytorch model inference result, try to visualize if possible
    create_process(
        'visualize pytorch model',
        target=visualize_model,
        args=(model_cfg_path, deploy_cfg_path, [checkpoint_path],
              args.test_img, args.device),
        kwargs=dict(
            backend=Backend.PYTORCH,
            output_file=osp.join(args.work_dir, 'output_pytorch.jpg'),
            show_result=args.show),
        ret_value=ret_value)
    logger.info('All processes succeeded.')


if __name__ == '__main__':
    main()