From 37a1b83567210fbf14bc55c332d5bf691465c260 Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Fri, 7 Jan 2022 15:03:01 +0800 Subject: [PATCH] dump cls_name (#50) --- mmdeploy/utils/export_info.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/mmdeploy/utils/export_info.py b/mmdeploy/utils/export_info.py index ebff78d14..82c6c7c2a 100644 --- a/mmdeploy/utils/export_info.py +++ b/mmdeploy/utils/export_info.py @@ -160,14 +160,21 @@ def get_inference_info(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, input_map=input_map) -task_component = { - Task.CLASSIFICATION.value: 'LinearClsHead', - Task.OBJECT_DETECTION.value: 'ResizeBBox', - Task.INSTANCE_SEGMENTATION.value: 'ResizeInstanceMask', - Task.SEGMENTATION.value: 'ResizeMask', - Task.SUPER_RESOLUTION.value: 'TensorToImg', - Task.TEXT_DETECTION.value: 'TextDetHead', - Task.TEXT_RECOGNITION.value: 'CTCConvertor' +task_map = { + Task.CLASSIFICATION: + dict(component='LinearClsHead', cls_name='Classifier'), + Task.OBJECT_DETECTION: + dict(component='ResizeBBox', cls_name='Detector'), + Task.INSTANCE_SEGMENTATION: + dict(component='ResizeInstanceMask', cls_name='Detector'), + Task.SEGMENTATION: + dict(component='ResizeMask', cls_name='Segmentor'), + Task.SUPER_RESOLUTION: + dict(component='TensorToImg', cls_name='Restorer'), + Task.TEXT_DETECTION: + dict(component='TextDetHead', cls_name='TextDetector'), + Task.TEXT_RECOGNITION: + dict(component='CTCConvertor', cls_name='TextRecognizer') } @@ -245,7 +252,7 @@ def get_postprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config) -> Dict: if task == Task.OBJECT_DETECTION and 'mask_thr_binary' in params: task = Task.INSTANCE_SEGMENTATION - component = task_component[task.value] + component = task_map[task]['component'] if task != Task.SUPER_RESOLUTION and task != Task.SEGMENTATION: if 'type' in params: component = params.pop('type') @@ -269,14 +276,16 @@ def get_deploy(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, work_dir (str): Work dir to save json files. Return: - dict: Composed of version, models and customs. + dict: Composed of version, task, models and customs. """ + task = get_task_type(deploy_cfg) + cls_name = task_map[task]['cls_name'] _, customs = get_model_name_customs( deploy_cfg, model_cfg, work_dir=work_dir) version = get_mmdpeloy_version() models = get_models(deploy_cfg, model_cfg, work_dir=work_dir) - return dict(version=version, models=models, customs=customs) + return dict(version=version, task=cls_name, models=models, customs=customs) def get_pipeline(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,