mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
dump cls_name (#50)
This commit is contained in:
parent
1a040036cc
commit
37a1b83567
@ -160,14 +160,21 @@ def get_inference_info(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
||||
input_map=input_map)
|
||||
|
||||
|
||||
task_component = {
|
||||
Task.CLASSIFICATION.value: 'LinearClsHead',
|
||||
Task.OBJECT_DETECTION.value: 'ResizeBBox',
|
||||
Task.INSTANCE_SEGMENTATION.value: 'ResizeInstanceMask',
|
||||
Task.SEGMENTATION.value: 'ResizeMask',
|
||||
Task.SUPER_RESOLUTION.value: 'TensorToImg',
|
||||
Task.TEXT_DETECTION.value: 'TextDetHead',
|
||||
Task.TEXT_RECOGNITION.value: 'CTCConvertor'
|
||||
task_map = {
|
||||
Task.CLASSIFICATION:
|
||||
dict(component='LinearClsHead', cls_name='Classifier'),
|
||||
Task.OBJECT_DETECTION:
|
||||
dict(component='ResizeBBox', cls_name='Detector'),
|
||||
Task.INSTANCE_SEGMENTATION:
|
||||
dict(component='ResizeInstanceMask', cls_name='Detector'),
|
||||
Task.SEGMENTATION:
|
||||
dict(component='ResizeMask', cls_name='Segmentor'),
|
||||
Task.SUPER_RESOLUTION:
|
||||
dict(component='TensorToImg', cls_name='Restorer'),
|
||||
Task.TEXT_DETECTION:
|
||||
dict(component='TextDetHead', cls_name='TextDetector'),
|
||||
Task.TEXT_RECOGNITION:
|
||||
dict(component='CTCConvertor', cls_name='TextRecognizer')
|
||||
}
|
||||
|
||||
|
||||
@ -245,7 +252,7 @@ def get_postprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config) -> Dict:
|
||||
if task == Task.OBJECT_DETECTION and 'mask_thr_binary' in params:
|
||||
task = Task.INSTANCE_SEGMENTATION
|
||||
|
||||
component = task_component[task.value]
|
||||
component = task_map[task]['component']
|
||||
if task != Task.SUPER_RESOLUTION and task != Task.SEGMENTATION:
|
||||
if 'type' in params:
|
||||
component = params.pop('type')
|
||||
@ -269,14 +276,16 @@ def get_deploy(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
||||
work_dir (str): Work dir to save json files.
|
||||
|
||||
Return:
|
||||
dict: Composed of version, models and customs.
|
||||
dict: Composed of version, task, models and customs.
|
||||
"""
|
||||
|
||||
task = get_task_type(deploy_cfg)
|
||||
cls_name = task_map[task]['cls_name']
|
||||
_, customs = get_model_name_customs(
|
||||
deploy_cfg, model_cfg, work_dir=work_dir)
|
||||
version = get_mmdpeloy_version()
|
||||
models = get_models(deploy_cfg, model_cfg, work_dir=work_dir)
|
||||
return dict(version=version, models=models, customs=customs)
|
||||
return dict(version=version, task=cls_name, models=models, customs=customs)
|
||||
|
||||
|
||||
def get_pipeline(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
||||
|
Loading…
x
Reference in New Issue
Block a user