mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
dump cls_name (#50)
This commit is contained in:
parent
1a040036cc
commit
37a1b83567
@ -160,14 +160,21 @@ def get_inference_info(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
|||||||
input_map=input_map)
|
input_map=input_map)
|
||||||
|
|
||||||
|
|
||||||
task_component = {
|
task_map = {
|
||||||
Task.CLASSIFICATION.value: 'LinearClsHead',
|
Task.CLASSIFICATION:
|
||||||
Task.OBJECT_DETECTION.value: 'ResizeBBox',
|
dict(component='LinearClsHead', cls_name='Classifier'),
|
||||||
Task.INSTANCE_SEGMENTATION.value: 'ResizeInstanceMask',
|
Task.OBJECT_DETECTION:
|
||||||
Task.SEGMENTATION.value: 'ResizeMask',
|
dict(component='ResizeBBox', cls_name='Detector'),
|
||||||
Task.SUPER_RESOLUTION.value: 'TensorToImg',
|
Task.INSTANCE_SEGMENTATION:
|
||||||
Task.TEXT_DETECTION.value: 'TextDetHead',
|
dict(component='ResizeInstanceMask', cls_name='Detector'),
|
||||||
Task.TEXT_RECOGNITION.value: 'CTCConvertor'
|
Task.SEGMENTATION:
|
||||||
|
dict(component='ResizeMask', cls_name='Segmentor'),
|
||||||
|
Task.SUPER_RESOLUTION:
|
||||||
|
dict(component='TensorToImg', cls_name='Restorer'),
|
||||||
|
Task.TEXT_DETECTION:
|
||||||
|
dict(component='TextDetHead', cls_name='TextDetector'),
|
||||||
|
Task.TEXT_RECOGNITION:
|
||||||
|
dict(component='CTCConvertor', cls_name='TextRecognizer')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -245,7 +252,7 @@ def get_postprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config) -> Dict:
|
|||||||
if task == Task.OBJECT_DETECTION and 'mask_thr_binary' in params:
|
if task == Task.OBJECT_DETECTION and 'mask_thr_binary' in params:
|
||||||
task = Task.INSTANCE_SEGMENTATION
|
task = Task.INSTANCE_SEGMENTATION
|
||||||
|
|
||||||
component = task_component[task.value]
|
component = task_map[task]['component']
|
||||||
if task != Task.SUPER_RESOLUTION and task != Task.SEGMENTATION:
|
if task != Task.SUPER_RESOLUTION and task != Task.SEGMENTATION:
|
||||||
if 'type' in params:
|
if 'type' in params:
|
||||||
component = params.pop('type')
|
component = params.pop('type')
|
||||||
@ -269,14 +276,16 @@ def get_deploy(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
|||||||
work_dir (str): Work dir to save json files.
|
work_dir (str): Work dir to save json files.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
dict: Composed of version, models and customs.
|
dict: Composed of version, task, models and customs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
task = get_task_type(deploy_cfg)
|
||||||
|
cls_name = task_map[task]['cls_name']
|
||||||
_, customs = get_model_name_customs(
|
_, customs = get_model_name_customs(
|
||||||
deploy_cfg, model_cfg, work_dir=work_dir)
|
deploy_cfg, model_cfg, work_dir=work_dir)
|
||||||
version = get_mmdpeloy_version()
|
version = get_mmdpeloy_version()
|
||||||
models = get_models(deploy_cfg, model_cfg, work_dir=work_dir)
|
models = get_models(deploy_cfg, model_cfg, work_dir=work_dir)
|
||||||
return dict(version=version, models=models, customs=customs)
|
return dict(version=version, task=cls_name, models=models, customs=customs)
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
def get_pipeline(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user