From cad0092801a56d9d89d4dea114e9a22de3c62cc0 Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Wed, 17 Aug 2022 18:01:26 +0800 Subject: [PATCH] [Fix] fix device error in dump-info (#912) * fix device error in dump-info * fix UT --- mmdeploy/backend/sdk/export_info.py | 59 +++++++++++++++++++---------- tests/test_utils/test_util.py | 2 +- tools/deploy.py | 7 +++- 3 files changed, 45 insertions(+), 23 deletions(-) diff --git a/mmdeploy/backend/sdk/export_info.py b/mmdeploy/backend/sdk/export_info.py index 8e68f413f..ded60c374 100644 --- a/mmdeploy/backend/sdk/export_info.py +++ b/mmdeploy/backend/sdk/export_info.py @@ -42,20 +42,21 @@ def get_task(deploy_cfg: mmcv.Config) -> Dict: def get_model_name_customs(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, - work_dir: str) -> Tuple: + work_dir: str, device: str) -> Tuple: """Get the model name and dump custom file. Args: deploy_cfg (mmcv.Config): Deploy config dict. model_cfg (mmcv.Config): The model config dict. work_dir (str): Work dir to save json files. + device (str): The device passed in. Return: tuple(): Composed of the model name and the custom info. """ task = get_task_type(deploy_cfg) task_processor = build_task_processor( - model_cfg=model_cfg, deploy_cfg=deploy_cfg, device='cpu') + model_cfg=model_cfg, deploy_cfg=deploy_cfg, device=device) name = task_processor.get_model_name() customs = [] if task == Task.TEXT_RECOGNITION: @@ -75,19 +76,21 @@ def get_model_name_customs(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, def get_models(deploy_cfg: Union[str, mmcv.Config], - model_cfg: Union[str, mmcv.Config], work_dir: str) -> List: + model_cfg: Union[str, mmcv.Config], work_dir: str, + device: str) -> List: """Get the output model informantion for deploy.json. Args: deploy_cfg (mmcv.Config): Deploy config dict. model_cfg (mmcv.Config): The model config dict. work_dir (str): Work dir to save json files. + device (str): The device passed in. Return: list[dict]: The list contains dicts composed of the model name, net, weghts, backend, precision batchsize and dynamic_shape. """ - name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir) + name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir, device) precision = 'FP32' ir_name = get_ir_config(deploy_cfg)['save_file'] net = ir_name @@ -149,19 +152,20 @@ def get_models(deploy_cfg: Union[str, mmcv.Config], def get_inference_info(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, - work_dir: str) -> Dict: + work_dir: str, device: str) -> Dict: """Get the inference information for pipeline.json. Args: deploy_cfg (mmcv.Config): Deploy config dict. model_cfg (mmcv.Config): The model config dict. work_dir (str): Work dir to save json files. + device (str): The device passed in. Return: dict: Composed of the model name, type, module, input, output and input_map. """ - name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir) + name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir, device) type = 'Task' module = 'Net' input = ['prep_output'] @@ -182,9 +186,17 @@ def get_inference_info(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, return return_dict -def get_preprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config): +def get_preprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, + device: str): + """Get the pre process information for pipeline.json. + + Args: + deploy_cfg (mmcv.Config): Deploy config dict. + model_cfg (mmcv.Config): The model config dict. + device (str): The device passed in. + """ task_processor = build_task_processor( - model_cfg=model_cfg, deploy_cfg=deploy_cfg, device='cpu') + model_cfg=model_cfg, deploy_cfg=deploy_cfg, device=device) pipeline = task_processor.get_preprocess() type = 'Task' module = 'Transform' @@ -233,12 +245,14 @@ def get_preprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config): transforms=transforms) -def get_postprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config) -> Dict: +def get_postprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, + device: str, **kwargs) -> Dict: """Get the post process information for pipeline.json. Args: deploy_cfg (mmcv.Config): Deploy config dict. model_cfg (mmcv.Config): The model config dict. + device (str): The device passed in. Return: dict: Composed of the model name, type, module, input, params and @@ -249,7 +263,7 @@ def get_postprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config) -> Dict: name = 'postprocess' task = get_task_type(deploy_cfg) task_processor = build_task_processor( - model_cfg=model_cfg, deploy_cfg=deploy_cfg, device='cpu') + model_cfg=model_cfg, deploy_cfg=deploy_cfg, device=device) params = task_processor.get_postprocess() # TODO remove after adding instance segmentation to task processor @@ -270,14 +284,15 @@ def get_postprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config) -> Dict: output=output) -def get_deploy(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, - work_dir: str) -> Dict: +def get_deploy(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, work_dir: str, + device: str) -> Dict: """Get the inference information for pipeline.json. Args: deploy_cfg (mmcv.Config): Deploy config dict. model_cfg (mmcv.Config): The model config dict. work_dir (str): Work dir to save json files. + device (str): The device passed in. Return: dict: Composed of version, task, models and customs. @@ -286,27 +301,28 @@ def get_deploy(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, task = get_task_type(deploy_cfg) cls_name = task_map[task]['cls_name'] _, customs = get_model_name_customs( - deploy_cfg, model_cfg, work_dir=work_dir) + deploy_cfg, model_cfg, work_dir=work_dir, device=device) version = get_mmdpeloy_version() - models = get_models(deploy_cfg, model_cfg, work_dir=work_dir) + models = get_models(deploy_cfg, model_cfg, work_dir, device) return dict(version=version, task=cls_name, models=models, customs=customs) def get_pipeline(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, - work_dir: str) -> Dict: + work_dir: str, device: str) -> Dict: """Get the inference information for pipeline.json. Args: deploy_cfg (mmcv.Config): Deploy config dict. model_cfg (mmcv.Config): The model config dict. work_dir (str): Work dir to save json files. + device (str): The device passed in. Return: dict: Composed of input node name, output node name and the tasks. """ - preprocess = get_preprocess(deploy_cfg, model_cfg) - infer_info = get_inference_info(deploy_cfg, model_cfg, work_dir=work_dir) - postprocess = get_postprocess(deploy_cfg, model_cfg) + preprocess = get_preprocess(deploy_cfg, model_cfg, device) + infer_info = get_inference_info(deploy_cfg, model_cfg, work_dir, device) + postprocess = get_postprocess(deploy_cfg, model_cfg, device) task = get_task_type(deploy_cfg) input_names = preprocess['input'] output_names = postprocess['output'] @@ -353,7 +369,8 @@ def get_detail(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, def export2SDK(deploy_cfg: Union[str, mmcv.Config], - model_cfg: Union[str, mmcv.Config], work_dir: str, pth: str): + model_cfg: Union[str, mmcv.Config], work_dir: str, pth: str, + device: str, **kwargs): """Export information to SDK. This function dump `deploy.json`, `pipeline.json` and `detail.json` to work dir. @@ -364,8 +381,8 @@ def export2SDK(deploy_cfg: Union[str, mmcv.Config], pth (str): The path of the model checkpoint weights. """ deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg) - deploy_info = get_deploy(deploy_cfg, model_cfg, work_dir=work_dir) - pipeline_info = get_pipeline(deploy_cfg, model_cfg, work_dir=work_dir) + deploy_info = get_deploy(deploy_cfg, model_cfg, work_dir, device) + pipeline_info = get_pipeline(deploy_cfg, model_cfg, work_dir, device) detail_info = get_detail(deploy_cfg, model_cfg, pth=pth) mmcv.dump( deploy_info, diff --git a/tests/test_utils/test_util.py b/tests/test_utils/test_util.py index 9f597d1d7..be273469b 100644 --- a/tests/test_utils/test_util.py +++ b/tests/test_utils/test_util.py @@ -417,7 +417,7 @@ def test_AdvancedEnum(): not importlib.util.find_spec('mmedit'), reason='requires mmedit') def test_export_info(): with tempfile.TemporaryDirectory() as dir: - export2SDK(correct_deploy_cfg, correct_model_cfg, dir, '') + export2SDK(correct_deploy_cfg, correct_model_cfg, dir, '', 'cpu') deploy_json = os.path.join(dir, 'deploy.json') pipeline_json = os.path.join(dir, 'pipeline.json') detail_json = os.path.join(dir, 'detail.json') diff --git a/tools/deploy.py b/tools/deploy.py index 73135a6e3..3dc2f2bf9 100644 --- a/tools/deploy.py +++ b/tools/deploy.py @@ -122,7 +122,12 @@ def main(): mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) if args.dump_info: - export2SDK(deploy_cfg, model_cfg, args.work_dir, pth=checkpoint_path) + export2SDK( + deploy_cfg, + model_cfg, + args.work_dir, + pth=checkpoint_path, + device=args.device) ret_value = mp.Value('d', 0, lock=False)