[Feture] Export preprocess and deploy information to SDK (#65)

* add export info

* add dump-info funciton

* add collect info

* fix lint

* add docstring

* docstring

* docstring
This commit is contained in:
AllentDan 2021-09-13 18:50:29 +08:00 committed by GitHub
parent 745c51f965
commit 10793f488e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 45 additions and 1 deletions

View File

@ -0,0 +1,38 @@
from typing import Union
import mmcv
from mmdeploy.utils import load_config
def dump_info(deploy_cfg: Union[str, mmcv.Config],
model_cfg: Union[str, mmcv.Config], work_dir: str):
"""Export information to SDK.
Args:
deploy_cfg (str | mmcv.Config): deploy config file or dict
model_cfg (str | mmcv.Config): model config file or dict
work_dir (str): work dir to save json files
"""
# TODO dump default values of transformation function to json
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
meta_keys = [
'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape',
'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg'
]
if 'transforms' in model_cfg.data.test.pipeline[-1]:
model_cfg.data.test.pipeline[-1]['transforms'][-1][
'meta_keys'] = meta_keys
else:
model_cfg.data.test.pipeline[-1]['meta_keys'] = meta_keys
mmcv.dump(
model_cfg.data.test.pipeline,
'{}/preprocess.json'.format(work_dir),
sort_keys=False,
indent=4)
mmcv.dump(
deploy_cfg._cfg_dict,
'{}/deploy_cfg.json'.format(work_dir),
sort_keys=False,
indent=4)

View File

@ -13,6 +13,7 @@ from mmdeploy.apis import (create_calib_table, extract_model, inference_model,
from mmdeploy.apis.utils import get_partition_cfg
from mmdeploy.utils.config_utils import (Backend, get_backend, get_codebase,
load_config)
from mmdeploy.utils.export_info import dump_info
def parse_args():
@ -37,6 +38,8 @@ def parse_args():
choices=list(logging._nameToLevel.keys()))
parser.add_argument(
'--show', action='store_true', help='Show detection outputs')
parser.add_argument(
'--dump-info', action='store_true', help='Output information for SDK')
args = parser.parse_args()
return args
@ -87,7 +90,10 @@ def main():
checkpoint_path = args.checkpoint
# load deploy_cfg
deploy_cfg = load_config(deploy_cfg_path)[0]
deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)
if args.dump_info:
dump_info(deploy_cfg, model_cfg, args.work_dir, args.img, args.device)
# create work_dir if not
mmcv.mkdir_or_exist(osp.abspath(args.work_dir))