dump cls_name (#50)

This commit is contained in:
AllentDan 2022-01-07 15:03:01 +08:00 committed by GitHub
parent 1a040036cc
commit 37a1b83567
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -160,14 +160,21 @@ def get_inference_info(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
input_map=input_map)
task_component = {
Task.CLASSIFICATION.value: 'LinearClsHead',
Task.OBJECT_DETECTION.value: 'ResizeBBox',
Task.INSTANCE_SEGMENTATION.value: 'ResizeInstanceMask',
Task.SEGMENTATION.value: 'ResizeMask',
Task.SUPER_RESOLUTION.value: 'TensorToImg',
Task.TEXT_DETECTION.value: 'TextDetHead',
Task.TEXT_RECOGNITION.value: 'CTCConvertor'
task_map = {
Task.CLASSIFICATION:
dict(component='LinearClsHead', cls_name='Classifier'),
Task.OBJECT_DETECTION:
dict(component='ResizeBBox', cls_name='Detector'),
Task.INSTANCE_SEGMENTATION:
dict(component='ResizeInstanceMask', cls_name='Detector'),
Task.SEGMENTATION:
dict(component='ResizeMask', cls_name='Segmentor'),
Task.SUPER_RESOLUTION:
dict(component='TensorToImg', cls_name='Restorer'),
Task.TEXT_DETECTION:
dict(component='TextDetHead', cls_name='TextDetector'),
Task.TEXT_RECOGNITION:
dict(component='CTCConvertor', cls_name='TextRecognizer')
}
@ -245,7 +252,7 @@ def get_postprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config) -> Dict:
if task == Task.OBJECT_DETECTION and 'mask_thr_binary' in params:
task = Task.INSTANCE_SEGMENTATION
component = task_component[task.value]
component = task_map[task]['component']
if task != Task.SUPER_RESOLUTION and task != Task.SEGMENTATION:
if 'type' in params:
component = params.pop('type')
@ -269,14 +276,16 @@ def get_deploy(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
work_dir (str): Work dir to save json files.
Return:
dict: Composed of version, models and customs.
dict: Composed of version, task, models and customs.
"""
task = get_task_type(deploy_cfg)
cls_name = task_map[task]['cls_name']
_, customs = get_model_name_customs(
deploy_cfg, model_cfg, work_dir=work_dir)
version = get_mmdpeloy_version()
models = get_models(deploy_cfg, model_cfg, work_dir=work_dir)
return dict(version=version, models=models, customs=customs)
return dict(version=version, task=cls_name, models=models, customs=customs)
def get_pipeline(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,