mmdeploy/tools/deploy.py
q.yao 4c0b36b7ff
[Refactor] Refactor config v1 (#80)
* [Refactor] Refactor configs according to new standard (#67)

* modify cfg and cfg_util

* modify tensorrt config

* fix bug

* lint

* Fix

1. Delete print
2. Modify the return value from "False, None" to "None" and related code
3. Rename 2 get functions

* modify apply_marks

* [Feature] Refactor ocr config (#71)

* add text detection config refactor

* add text recognition refactor

* add static exporting for mmocr

* fix lint

* set max space in child config

* use Sequence[int] instead

* add assert input_shape

* fix static bug and add ppl ort and trt static (#77)

* [Feature] Refine setup.py (#61)

* add setup.py and related files

* lint

* Edit requirements

* modify onnx version

* modify according to comments

* [Refactor] Refactor mmseg configs (#73)

* refactor mmseg config

* change create_input

* fix lint

* fix lint

* fix lint

* fix yapf

* fix yapf

* update export

* remove Segmentation

* remove task assert

* add onnx_config

* remove hardcode

* Inherit with static

* Remove blank line

* Add segmentation task enum

* add assert task

* mmocr version 0.3.0 (#79)

* add dump_info

* [Feature]: Refactor config in mmdet (#75)

* support onnxruntime

* add two stage

* test two-stage ort and ppl

* update fcos post_params

* fix calib

* test ok with maskrcnn dynamic

* add empty line

* add static into config filename

* add input_shape to create_input in mmdet

* add static to some configs

* remove todo codes

* remove partition config in base

* refactor create_input

* rename task name in mmdet

* return None if input_shape is None

* add size info into mmdet configs filenames

* reorganize mmdet configs

* add object detection task for mmdet

* rename get_mmdet_params

* keep naming style consistent

* update post_params for fcos

* fix typo in ncnn config

* [Refactor] Refactor mmedit static config (#78)

* add static cfg

* update create_input

* [Refactor]: Refactor mmcls configs (#74)

* refactor mmcls2.0

* fix classify_tensorrt_dynamic.py

* fix classify_tensorrt_dynamic.py

* classify_tensorrt_dynamic_int8.py

* fix file name

* fix ncnn ppl

* update prepare_input.py

* update utils.py

* update constant.py

* add

* fix prepare_input.py

* fix prepare_input.py

* add static config file

* add blank lines

* fix prepare_input.py(wait test)

* fix input_shape(wait test)

* Update prepare_input.py

* fix classification_tensorrt_dynamic(wait test)

* fix classification_tensorrt_dynamic_int8(wait test)

* fix classification_tensorrt_static_int8(wait test)

* Rename classification_tensorrt_dynamic.py to classification_tensorrt_dynamic-224x224-224x224.py

* Rename classification_tensorrt_dynamic_int8.py to classification_tensorrt_dynamic_int8-224x224-224x224.py

* Rename classification_tensorrt_dynamic_int8-224x224-224x224.py to classification_tensorrt_int8_dynamic_224x224-224x224.py

* Rename classification_tensorrt_dynamic-224x224-224x224.py to classification_tensorrt_dynamic_224x224-224x224.py

* Rename classification_tensorrt_static.py to classification_tensorrt_static_224x224.py

* Rename classification_tensorrt_static_int8.py to classification_tensorrt_int8_static_224x224.py

* Update prepare_input.py

* Rename classification_tensorrt_dynamic_224x224-224x224.py to classification_tensorrt_dynamic-224x224-224x224.py

* Rename classification_tensorrt_int8_dynamic_224x224-224x224.py to classification_tensorrt_int8-dynamic_224x224-224x224.py

* Rename classification_tensorrt_int8-dynamic_224x224-224x224.py to classification_tensorrt_int8_dynamic-224x224-224x224.py

* Rename classification_tensorrt_int8_static_224x224.py to classification_tensorrt_int8_static-224x224.py

* Rename classification_tensorrt_static_224x224.py to classification_tensorrt_static-224x224.py

* Update prepare_input.py

* Update prepare_input.py

* Update prepare_input.py

* Update prepare_input.py

* Update prepare_input.py

* Update prepare_input.py

* Update prepare_input.py

* change logging msg

Co-authored-by: maningsheng <mnsheng@yeah.net>

* fix

* fix else branch

* fix bug for trt in mmseg

* enable dump trt info

* fix trt static for mmdet

* remove two-stage_partition_tensorrt_static-800x1344 config

* fix wrong backend in ppl config

* fix partition calibration

Co-authored-by: Yifan Zhou <singlezombie@163.com>
Co-authored-by: AllentDan <41138331+AllentDan@users.noreply.github.com>
Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com>
Co-authored-by: RunningLeon <maningsheng@sensetime.com>
Co-authored-by: VVsssssk <88368822+VVsssssk@users.noreply.github.com>
Co-authored-by: maningsheng <mnsheng@yeah.net>
Co-authored-by: AllentDan <AllentDan@yeah.net>
2021-09-16 10:26:09 +08:00

import argparse
import logging
import os.path as osp
import subprocess
from functools import partial

import mmcv
import torch.multiprocessing as mp
from torch.multiprocessing import Process, set_start_method

from mmdeploy.apis import (create_calib_table, extract_model, inference_model,
                           torch2onnx)
from mmdeploy.apis.utils import get_partition_cfg as parse_partition_cfg
from mmdeploy.utils import (Backend, get_backend, get_calib_filename,
                            get_codebase, get_model_inputs, get_onnx_config,
                            get_partition_config, load_config)
from mmdeploy.utils.export_info import dump_info
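
# Overall flow of this tool (see main() below):
#   1. torch2onnx: export the PyTorch model to ONNX
#   2. optionally partition the ONNX model and build a calibration table
#   3. convert the ONNX file(s) to the chosen backend (e.g. TensorRT, ncnn)
#   4. visualize outputs of both the backend model and the PyTorch model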


def parse_args():
    parser = argparse.ArgumentParser(description='Export model to backends.')
    parser.add_argument('deploy_cfg', help='deploy config path')
    parser.add_argument('model_cfg', help='model config path')
    parser.add_argument('checkpoint', help='model checkpoint path')
    parser.add_argument('img', help='image used to convert the model')
    parser.add_argument(
        '--test-img', default=None, help='image used to test model')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--calib-dataset-cfg',
        help='dataset config path used for calibration.',
        default=None)
    parser.add_argument(
        '--device', help='device used for conversion', default='cpu')
    parser.add_argument(
        '--log-level',
        help='set log level',
        default='INFO',
        choices=list(logging._nameToLevel.keys()))
    parser.add_argument(
        '--show', action='store_true', help='Show detection outputs')
    parser.add_argument(
        '--dump-info', action='store_true', help='Output information for SDK')
    args = parser.parse_args()
    return args
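

# Each stage runs in its own spawned subprocess via create_process() below;
# a failure in one stage is reported back through a shared value instead of
# crashing the parent, and every stage starts from a fresh interpreter.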
def target_wrapper(target, log_level, ret_value, *args, **kwargs):
    # Run `target` in the subprocess, mirroring the parent's log level and
    # reporting success (0) or failure (-1) through the shared `ret_value`.
    logger = logging.getLogger()
    logger.setLevel(log_level)
    if ret_value is not None:
        ret_value.value = -1
    try:
        result = target(*args, **kwargs)
        if ret_value is not None:
            ret_value.value = 0
        return result
    except Exception as e:
        logging.error(e)


def create_process(name, target, args, kwargs, ret_value=None):
    logging.info(f'{name} start.')
    log_level = logging.getLogger().level

    wrap_func = partial(target_wrapper, target, log_level, ret_value)

    process = Process(target=wrap_func, args=args, kwargs=kwargs)
    process.start()
    process.join()

    if ret_value is not None:
        if ret_value.value != 0:
            logging.error(f'{name} failed.')
            exit(1)
        else:
            logging.info(f'{name} success.')
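

# Example invocation (all paths below are illustrative placeholders):
#   python tools/deploy.py <deploy_cfg>.py <model_cfg>.py <checkpoint>.pth \
#       demo.jpg --work-dir work_dir --device cuda:0 --dump-info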


def main():
    args = parse_args()
    set_start_method('spawn')

    logger = logging.getLogger()
    logger.setLevel(args.log_level)

    deploy_cfg_path = args.deploy_cfg
    model_cfg_path = args.model_cfg
    checkpoint_path = args.checkpoint

    # load deploy and model config
    deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)

    if args.dump_info:
        dump_info(deploy_cfg, model_cfg, args.work_dir)

    # create work_dir if it does not exist
    mmcv.mkdir_or_exist(osp.abspath(args.work_dir))

    ret_value = mp.Value('d', 0, lock=False)
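    # `ret_value` is shared with every worker process: target_wrapper writes
    # 0 on success and -1 on failure, and create_process reads it back.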

    # convert onnx
    onnx_save_file = get_onnx_config(deploy_cfg)['save_file']
    create_process(
        'torch2onnx',
        target=torch2onnx,
        args=(args.img, args.work_dir, onnx_save_file, deploy_cfg_path,
              model_cfg_path, checkpoint_path),
        kwargs=dict(device=args.device),
        ret_value=ret_value)

    # convert backend
    onnx_files = [osp.join(args.work_dir, onnx_save_file)]

    # partition model
    partition_cfgs = get_partition_config(deploy_cfg)

    if partition_cfgs is not None:
        if 'partition_cfg' in partition_cfgs:
            partition_cfgs = partition_cfgs.get('partition_cfg', None)
        else:
            assert 'type' in partition_cfgs
            partition_cfgs = parse_partition_cfg(
                get_codebase(deploy_cfg), partition_cfgs['type'])

        origin_onnx_file = onnx_files[0]
        onnx_files = []
        for partition_cfg in partition_cfgs:
            save_file = partition_cfg['save_file']
            save_path = osp.join(args.work_dir, save_file)
            start = partition_cfg['start']
            end = partition_cfg['end']
            dynamic_axes = partition_cfg.get('dynamic_axes', None)

            create_process(
                f'partition model {save_file} with start: {start}, '
                f'end: {end}',
                extract_model,
                args=(origin_onnx_file, start, end),
                kwargs=dict(dynamic_axes=dynamic_axes, save_file=save_path),
                ret_value=ret_value)

            onnx_files.append(save_path)

    # calib data
    calib_filename = get_calib_filename(deploy_cfg)
    if calib_filename is not None:
        calib_path = osp.join(args.work_dir, calib_filename)

        create_process(
            'calibration',
            create_calib_table,
            args=(calib_path, deploy_cfg_path, model_cfg_path,
                  checkpoint_path),
            kwargs=dict(
                dataset_cfg=args.calib_dataset_cfg,
                dataset_type='val',
                device=args.device),
            ret_value=ret_value)

    backend_files = onnx_files
    # convert backend
    backend = get_backend(deploy_cfg, 'default')
    if backend == Backend.TENSORRT:
        model_params = get_model_inputs(deploy_cfg)
        assert len(model_params) == len(onnx_files)

        from mmdeploy.apis.tensorrt import is_available as trt_is_available
        from mmdeploy.apis.tensorrt import onnx2tensorrt
        assert trt_is_available(), (
            'TensorRT is not available, please install TensorRT and build '
            'TensorRT custom ops first.')

        backend_files = []
        for model_id, model_param, onnx_path in zip(
                range(len(onnx_files)), model_params, onnx_files):
            onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
            save_file = model_param.get('save_file', onnx_name + '.engine')

            partition_type = 'end2end' if partition_cfgs is None \
                else onnx_name
            create_process(
                f'onnx2tensorrt of {onnx_path}',
                target=onnx2tensorrt,
                args=(args.work_dir, save_file, model_id, deploy_cfg_path,
                      onnx_path),
                kwargs=dict(
                    device=args.device, partition_type=partition_type),
                ret_value=ret_value)

            backend_files.append(osp.join(args.work_dir, save_file))
    elif backend == Backend.NCNN:
        from mmdeploy.apis.ncnn import get_onnx2ncnn_path
        from mmdeploy.apis.ncnn import is_available as is_available_ncnn
        if not is_available_ncnn():
            logging.error('ncnn support is not available.')
            exit(-1)

        onnx2ncnn_path = get_onnx2ncnn_path()

        backend_files = []
        for onnx_path in onnx_files:
            onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
            save_param = onnx_name + '.param'
            save_bin = onnx_name + '.bin'

            save_param = osp.join(args.work_dir, save_param)
            save_bin = osp.join(args.work_dir, save_bin)

            subprocess.call([onnx2ncnn_path, onnx_path, save_param, save_bin])

            backend_files += [save_param, save_bin]
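
    # Any other backend (e.g. ONNX Runtime or ppl) consumes the ONNX file(s)
    # directly, so `backend_files` keeps the default assigned above.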
    if args.test_img is None:
        args.test_img = args.img

    # visualize model of the backend
    create_process(
        f'visualize {backend.value} model',
        target=inference_model,
        args=(model_cfg_path, deploy_cfg_path, backend_files, args.test_img,
              args.device),
        kwargs=dict(
            backend=backend,
            output_file=f'output_{backend.value}.jpg',
            show_result=args.show),
        ret_value=ret_value)

    # visualize pytorch model
    create_process(
        'visualize pytorch model',
        target=inference_model,
        args=(model_cfg_path, deploy_cfg_path, [checkpoint_path],
              args.test_img, args.device),
        kwargs=dict(
            backend=Backend.PYTORCH,
            output_file='output_pytorch.jpg',
            show_result=args.show),
        ret_value=ret_value)

    logging.info('All processes succeeded.')


if __name__ == '__main__':
    main()