From 633d74fb36aa81579b260818eb72a15e29d20737 Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Mon, 19 Sep 2022 17:15:21 +0800 Subject: [PATCH] refactor dump-info (#1040) --- mmdeploy/backend/sdk/export_info.py | 185 ++++-------------- mmdeploy/codebase/base/task.py | 6 +- .../codebase/mmcls/deploy/classification.py | 6 +- .../codebase/mmdet/deploy/object_detection.py | 6 +- .../mmdet3d/deploy/voxel_detection.py | 6 +- .../mmedit/deploy/super_resolution.py | 6 +- .../codebase/mmocr/deploy/text_detection.py | 43 +++- .../codebase/mmocr/deploy/text_recognition.py | 53 ++++- .../codebase/mmpose/deploy/pose_detection.py | 6 +- .../mmrotate/deploy/rotated_detection.py | 6 +- .../codebase/mmseg/deploy/segmentation.py | 6 +- mmdeploy/utils/__init__.py | 4 +- mmdeploy/utils/config_utils.py | 23 +++ 13 files changed, 171 insertions(+), 185 deletions(-) diff --git a/mmdeploy/backend/sdk/export_info.py b/mmdeploy/backend/sdk/export_info.py index 1d5fa69ef..aa7ab332a 100644 --- a/mmdeploy/backend/sdk/export_info.py +++ b/mmdeploy/backend/sdk/export_info.py @@ -6,9 +6,9 @@ from typing import Dict, List, Tuple, Union import mmengine from mmdeploy.apis import build_task_processor -from mmdeploy.utils import (Backend, Task, get_backend, get_codebase, - get_common_config, get_ir_config, get_root_logger, - get_task_type, is_dynamic_batch, load_config) +from mmdeploy.utils import (Task, get_backend, get_codebase, get_ir_config, + get_precision, get_root_logger, get_task_type, + is_dynamic_batch, load_config) from mmdeploy.utils.constants import SDK_TASK_MAP as task_map @@ -79,60 +79,31 @@ def get_models(deploy_cfg: Union[str, mmengine.Config], name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir) precision = 'FP32' ir_name = get_ir_config(deploy_cfg)['save_file'] - net = ir_name weights = '' - backend = get_backend(deploy_cfg=deploy_cfg) + backend = get_backend(deploy_cfg=deploy_cfg).value - def replace_suffix(file_name: str, dst_suffix: str) -> str: - """Replace the suffix to the destination one. - - Args: - file_name (str): The file name to be operated. - dst_suffix (str): The destination suffix. - - Return: - str: The file name of which the suffix has been replaced. - """ - return re.sub(r'\.[a-z]+', dst_suffix, file_name) - - if backend == Backend.TENSORRT: - net = replace_suffix(ir_name, '.engine') - common_cfg = get_common_config(deploy_cfg) - fp16_mode = common_cfg.get('fp16_mode', False) - int8_mode = common_cfg.get('int8_mode', False) - if fp16_mode: - precision = 'FP16' - if int8_mode: - precision = 'INT8' - elif backend == Backend.PPLNN: - precision = 'FP16' - weights = replace_suffix(ir_name, '.json') - net = ir_name - elif backend == Backend.OPENVINO: - net = replace_suffix(ir_name, '.xml') - weights = replace_suffix(ir_name, '.bin') - elif backend == Backend.NCNN: - net = replace_suffix(ir_name, '.param') - weights = replace_suffix(ir_name, '.bin') - if 'precision' in deploy_cfg['backend_config']: - precision = deploy_cfg['backend_config']['precision'] - elif backend == Backend.SNPE: - net = replace_suffix(ir_name, '.dlc') - elif backend in [Backend.ONNXRUNTIME, Backend.TORCHSCRIPT]: - pass - else: - raise NotImplementedError(f'Not supported backend: {backend.value}.') + backend_net = dict( + tensorrt=lambda file: re.sub(r'\.[a-z]+', '.engine', file), + openvino=lambda file: re.sub(r'\.[a-z]+', '.xml', file), + ncnn=lambda file: re.sub(r'\.[a-z]+', '.param', file), + snpe=lambda file: re.sub(r'\.[a-z]+', '.dlc', file)) + backend_weights = dict( + pplnn=lambda file: re.sub(r'\.[a-z]+', '.json', file), + openvino=lambda file: re.sub(r'\.[a-z]+', '.bin', file), + ncnn=lambda file: re.sub(r'\.[a-z]+', '.bin', file)) + net = backend_net.get(backend, lambda x: x)(ir_name) + weights = backend_weights.get(backend, lambda x: weights)(ir_name) + precision = get_precision(deploy_cfg) dynamic_shape = is_dynamic_batch(deploy_cfg, input_name='input') - batch_size = 1 return [ dict( name=name, net=net, weights=weights, - backend=backend.value, + backend=backend, precision=precision, - batch_size=batch_size, + batch_size=1, dynamic_shape=dynamic_shape) ] @@ -151,20 +122,16 @@ def get_inference_info(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config, input_map. """ name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir) - type = 'Task' - module = 'Net' - input = ['prep_output'] - output = ['infer_output'] ir_config = get_ir_config(deploy_cfg) input_names = ir_config.get('input_names', None) input_name = input_names[0] if input_names else 'input' input_map = dict(img=input_name) return_dict = dict( name=name, - type=type, - module=module, - input=input, - output=output, + type='Task', + module='Net', + input=['prep_output'], + output=['infer_output'], input_map=input_map) if 'use_vulkan' in deploy_cfg['backend_config']: return_dict['use_vulkan'] = deploy_cfg['backend_config']['use_vulkan'] @@ -174,71 +141,15 @@ def get_inference_info(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config, def get_preprocess(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config): task_processor = build_task_processor( model_cfg=model_cfg, deploy_cfg=deploy_cfg, device='cpu') - pipeline = task_processor.get_preprocess() - type = 'Task' - module = 'Transform' - name = 'Preprocess' - input = ['img'] - output = ['prep_output'] - meta_keys = [ - 'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape', - 'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg', 'valid_ratio' - ] - transforms = [ - item for item in pipeline - if 'Random' not in item['type'] and 'RescaleToZeroOne' not in - item['type'] and 'Annotation' not in item['type'] - ] - for i, transform in enumerate(transforms): - if 'keys' in transform and transform['keys'] == ['lq']: - transform['keys'] = ['img'] - if 'key' in transform and transform['key'] == 'lq': - transform['key'] = 'img' - if transform['type'] == 'ResizeEdge': - transform['type'] = 'Resize' - transform['keep_ratio'] = True - # now the sdk of class has bugs, because ResizeEdge not implement - # in sdk. - transform['size'] = (transform['scale'], transform['scale']) - if transform['type'] in ('PackTextDetInputs', 'PackTextRecogInputs'): - meta_keys += transform[ - 'meta_keys'] if 'meta_keys' in transform else [] - transform['meta_keys'] = list(set(meta_keys)) - transform['keys'] = ['img'] - transforms[i]['type'] = 'Collect' - if transform['type'] == 'PackDetInputs' or \ - transform['type'] == 'PackClsInputs': - transforms.insert(i, dict(type='DefaultFormatBundle')) - transform['type'] = 'Collect' - if 'keys' not in transform: - transform['keys'] = ['img'] - if transform['type'] == 'Resize': - transforms[i]['size'] = transforms[i]['scale'] - - data_preprocessor = model_cfg.model.data_preprocessor - transforms.insert(-1, dict(type='DefaultFormatBundle')) - transforms.insert( - -2, - dict( - type='Pad', - size_divisor=data_preprocessor.get('pad_size_divisor', 1))) - transforms.insert( - -3, - dict( - type='Normalize', - to_rgb=data_preprocessor.get('bgr_to_rgb', False), - mean=data_preprocessor.get('mean', [0, 0, 0]), - std=data_preprocessor.get('std', [1, 1, 1]))) - - assert transforms[0]['type'] == 'LoadImageFromFile', 'The first item type'\ - ' of pipeline should be LoadImageFromFile' - + transforms = task_processor.get_preprocess() + assert transforms[0]['type'] == 'LoadImageFromFile', 'The first item'\ + ' type of pipeline should be LoadImageFromFile' return dict( - type=type, - module=module, - name=name, - input=input, - output=output, + type='Task', + module='Transform', + name='Preprocess', + input=['img'], + output=['prep_output'], transforms=transforms) @@ -255,38 +166,17 @@ def get_postprocess(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config, dict: Composed of the model name, type, module, input, params and output. """ - module = get_codebase(deploy_cfg).value - type = 'Task' - name = 'postprocess' - params = dict() - task = get_task_type(deploy_cfg) task_processor = build_task_processor( model_cfg=model_cfg, deploy_cfg=deploy_cfg, device='cpu') - post_processor = task_processor.get_postprocess() + post_processor = task_processor.get_postprocess(work_dir) - # TODO remove after adding instance segmentation to task processor - if task == Task.OBJECT_DETECTION and 'mask_thr_binary' in post_processor: - task = Task.INSTANCE_SEGMENTATION - - component = task_map[task]['component'] - if task not in (Task.SUPER_RESOLUTION, Task.SEGMENTATION): - if 'type' in post_processor: - component = post_processor.pop('type') - output = ['post_output'] - - if task == Task.TEXT_RECOGNITION: - import shutil - shutil.copy(model_cfg.dictionary.dict_file, - f'{work_dir}/dict_file.txt') - with_padding = model_cfg.dictionary.get('with_padding', False) - params = dict(dict_file='dict_file.txt', with_padding=with_padding) return dict( - type=type, - module=module, - name=name, - component=component, - params=params, - output=output) + type='Task', + module=get_codebase(deploy_cfg).value, + name='postprocess', + component=post_processor['type'], + params=post_processor.get('params', dict()), + output=['post_output']) def get_deploy(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config, @@ -384,7 +274,6 @@ def export2SDK(deploy_cfg: Union[str, mmengine.Config], pth (str): The path of the model checkpoint weights. """ deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg) - print(f'debugging what is type(model_cfg): {type(model_cfg)}') deploy_info = get_deploy(deploy_cfg, model_cfg, work_dir=work_dir) pipeline_info = get_pipeline(deploy_cfg, model_cfg, work_dir=work_dir) detail_info = get_detail(deploy_cfg, model_cfg, pth=pth) diff --git a/mmdeploy/codebase/base/task.py b/mmdeploy/codebase/base/task.py index 93114205c..0f6699906 100644 --- a/mmdeploy/codebase/base/task.py +++ b/mmdeploy/codebase/base/task.py @@ -322,7 +322,7 @@ class BaseTask(metaclass=ABCMeta): return input_data['inputs'] @abstractmethod - def get_preprocess(self) -> Dict: + def get_preprocess(self, *args, **kwargs) -> Dict: """Get the preprocess information for SDK. Return: @@ -331,7 +331,7 @@ class BaseTask(metaclass=ABCMeta): pass @abstractmethod - def get_postprocess(self) -> Dict: + def get_postprocess(self, *args, **kwargs) -> Dict: """Get the postprocess information for SDK. Return: @@ -340,7 +340,7 @@ class BaseTask(metaclass=ABCMeta): pass @abstractmethod - def get_model_name(self) -> str: + def get_model_name(self, *args, **kwargs) -> str: """Get the model name. Return: diff --git a/mmdeploy/codebase/mmcls/deploy/classification.py b/mmdeploy/codebase/mmcls/deploy/classification.py index 9a21382db..3457db00e 100644 --- a/mmdeploy/codebase/mmcls/deploy/classification.py +++ b/mmdeploy/codebase/mmcls/deploy/classification.py @@ -241,7 +241,7 @@ class Classification(BaseTask): """ raise NotImplementedError('Not supported yet.') - def get_preprocess(self) -> Dict: + def get_preprocess(self, *args, **kwargs) -> Dict: """Get the preprocess information for SDK. Return: @@ -252,7 +252,7 @@ class Classification(BaseTask): preprocess = cfg.test_pipeline return preprocess - def get_postprocess(self) -> Dict: + def get_postprocess(self, *args, **kwargs) -> Dict: """Get the postprocess information for SDK. Return: @@ -269,7 +269,7 @@ class Classification(BaseTask): postprocess.topk = max(topk) return postprocess - def get_model_name(self) -> str: + def get_model_name(self, *args, **kwargs) -> str: """Get the model name. Return: diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection.py b/mmdeploy/codebase/mmdet/deploy/object_detection.py index 58a282731..c0de3eb0f 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection.py @@ -206,7 +206,7 @@ class ObjectDetection(BaseTask): f'Unknown partition_type {partition_type}' return MMDET_PARTITION_CFG[partition_type] - def get_preprocess(self) -> Dict: + def get_preprocess(self, *args, **kwargs) -> Dict: """Get the preprocess information for SDK. Return: @@ -217,7 +217,7 @@ class ObjectDetection(BaseTask): preprocess = model_cfg.test_pipeline return preprocess - def get_postprocess(self) -> Dict: + def get_postprocess(self, *args, **kwargs) -> Dict: """Get the postprocess information for SDK. Return: @@ -233,7 +233,7 @@ class ObjectDetection(BaseTask): 'mask_thr_binary'] return postprocess - def get_model_name(self) -> str: + def get_model_name(self, *args, **kwargs) -> str: """Get the model name. Return: diff --git a/mmdeploy/codebase/mmdet3d/deploy/voxel_detection.py b/mmdeploy/codebase/mmdet3d/deploy/voxel_detection.py index fcd5b1031..7a1d7c278 100644 --- a/mmdeploy/codebase/mmdet3d/deploy/voxel_detection.py +++ b/mmdeploy/codebase/mmdet3d/deploy/voxel_detection.py @@ -199,7 +199,7 @@ class VoxelDetection(BaseTask): eval_kwargs.update(dict(metric=metrics, **kwargs)) dataset.evaluate(outputs, **eval_kwargs) - def get_model_name(self) -> str: + def get_model_name(self, *args, **kwargs) -> str: """Get the model name. Return: @@ -230,7 +230,7 @@ class VoxelDetection(BaseTask): """ raise NotImplementedError - def get_postprocess(self) -> Dict: + def get_postprocess(self, *args, **kwargs) -> Dict: """Get the postprocess information for SDK. Return: @@ -238,7 +238,7 @@ class VoxelDetection(BaseTask): """ raise NotImplementedError - def get_preprocess(self) -> Dict: + def get_preprocess(self, *args, **kwargs) -> Dict: """Get the preprocess information for SDK. Return: diff --git a/mmdeploy/codebase/mmedit/deploy/super_resolution.py b/mmdeploy/codebase/mmedit/deploy/super_resolution.py index 2961b6873..14f03821b 100644 --- a/mmdeploy/codebase/mmedit/deploy/super_resolution.py +++ b/mmdeploy/codebase/mmedit/deploy/super_resolution.py @@ -286,7 +286,7 @@ class SuperResolution(BaseTask): for stat in stats: logger.info('Eval-{}: {}'.format(stat, stats[stat])) - def get_preprocess(self) -> Dict: + def get_preprocess(self, *args, **kwargs) -> Dict: """Get the preprocess information for SDK. Return: @@ -300,7 +300,7 @@ class SuperResolution(BaseTask): item['std'] = [255, 255, 255] return preprocess - def get_postprocess(self) -> Dict: + def get_postprocess(self, *args, **kwargs) -> Dict: """Get the postprocess information for SDK. Return: @@ -308,7 +308,7 @@ class SuperResolution(BaseTask): """ return dict() - def get_model_name(self) -> str: + def get_model_name(self, *args, **kwargs) -> str: """Get the model name. Return: diff --git a/mmdeploy/codebase/mmocr/deploy/text_detection.py b/mmdeploy/codebase/mmocr/deploy/text_detection.py index 1105ebbfd..2e5189fdf 100644 --- a/mmdeploy/codebase/mmocr/deploy/text_detection.py +++ b/mmdeploy/codebase/mmocr/deploy/text_detection.py @@ -212,7 +212,7 @@ class TextDetection(BaseTask): """ raise NotImplementedError('Not supported yet.') - def get_preprocess(self) -> Dict: + def get_preprocess(self, *args, **kwargs) -> Dict: """Get the preprocess information for SDK. Return: @@ -220,10 +220,43 @@ class TextDetection(BaseTask): """ input_shape = get_input_shape(self.deploy_cfg) model_cfg = process_model_config(self.model_cfg, [''], input_shape) - preprocess = model_cfg.test_dataloader.dataset.pipeline - return preprocess + pipeline = model_cfg.test_dataloader.dataset.pipeline + meta_keys = [ + 'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape', + 'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg', + 'valid_ratio' + ] + transforms = [ + item for item in pipeline if 'Random' not in item['type'] + and 'Annotation' not in item['type'] + ] + for i, transform in enumerate(transforms): + if transform['type'] == 'PackTextDetInputs': + meta_keys += transform[ + 'meta_keys'] if 'meta_keys' in transform else [] + transform['meta_keys'] = list(set(meta_keys)) + transform['keys'] = ['img'] + transforms[i]['type'] = 'Collect' + if transform['type'] == 'Resize': + transforms[i]['size'] = transforms[i]['scale'] - def get_postprocess(self) -> Dict: + data_preprocessor = model_cfg.model.data_preprocessor + transforms.insert(-1, dict(type='DefaultFormatBundle')) + transforms.insert( + -2, + dict( + type='Pad', + size_divisor=data_preprocessor.get('pad_size_divisor', 1))) + transforms.insert( + -3, + dict( + type='Normalize', + to_rgb=data_preprocessor.get('bgr_to_rgb', False), + mean=data_preprocessor.get('mean', [0, 0, 0]), + std=data_preprocessor.get('std', [1, 1, 1]))) + return transforms + + def get_postprocess(self, *args, **kwargs) -> Dict: """Get the postprocess information for SDK. Return: @@ -232,7 +265,7 @@ class TextDetection(BaseTask): postprocess = self.model_cfg.model.det_head return postprocess - def get_model_name(self) -> str: + def get_model_name(self, *args, **kwargs) -> str: """Get the model name. Return: diff --git a/mmdeploy/codebase/mmocr/deploy/text_recognition.py b/mmdeploy/codebase/mmocr/deploy/text_recognition.py index dabfce7ce..9f2a6286e 100644 --- a/mmdeploy/codebase/mmocr/deploy/text_recognition.py +++ b/mmdeploy/codebase/mmocr/deploy/text_recognition.py @@ -218,7 +218,7 @@ class TextRecognition(BaseTask): """ raise NotImplementedError('Not supported yet.') - def get_preprocess(self) -> Dict: + def get_preprocess(self, *args, **kwargs) -> Dict: """Get the preprocess information for SDK. Return: @@ -226,21 +226,62 @@ class TextRecognition(BaseTask): """ input_shape = get_input_shape(self.deploy_cfg) model_cfg = process_model_config(self.model_cfg, [''], input_shape) - preprocess = model_cfg.test_dataloader.dataset.pipeline - return preprocess + pipeline = model_cfg.test_dataloader.dataset.pipeline + meta_keys = [ + 'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape', + 'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg', + 'valid_ratio' + ] + transforms = [ + item for item in pipeline if 'Random' not in item['type'] + and 'Annotation' not in item['type'] + ] + for i, transform in enumerate(transforms): + if transform['type'] == 'PackTextRecogInputs': + meta_keys += transform[ + 'meta_keys'] if 'meta_keys' in transform else [] + transform['meta_keys'] = list(set(meta_keys)) + transform['keys'] = ['img'] + transforms[i]['type'] = 'Collect' + if transform['type'] == 'Resize': + transforms[i]['size'] = transforms[i]['scale'] - def get_postprocess(self) -> Dict: + data_preprocessor = model_cfg.model.data_preprocessor + transforms.insert(-1, dict(type='DefaultFormatBundle')) + transforms.insert( + -2, + dict( + type='Pad', + size_divisor=data_preprocessor.get('pad_size_divisor', 1))) + transforms.insert( + -3, + dict( + type='Normalize', + to_rgb=data_preprocessor.get('bgr_to_rgb', False), + mean=data_preprocessor.get('mean', [0, 0, 0]), + std=data_preprocessor.get('std', [1, 1, 1]))) + return transforms + + def get_postprocess(self, + work_dir: Optional[str] = None, + **kwargs) -> Dict: """Get the postprocess information for SDK. Return: - dict: Composed of the postprocess information. + Dict: Composed of the postprocess information. """ postprocess = self.model_cfg.model.decoder.postprocessor if postprocess.type == 'CTCPostProcessor': postprocess.type = 'CTCConvertor' + import shutil + shutil.copy(self.model_cfg.dictionary.dict_file, + f'{work_dir}/dict_file.txt') + with_padding = self.model_cfg.dictionary.get('with_padding', False) + params = dict(dict_file='dict_file.txt', with_padding=with_padding) + postprocess['params'] = params return postprocess - def get_model_name(self) -> str: + def get_model_name(self, *args, **kwargs) -> str: """Get the model name. Return: diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection.py b/mmdeploy/codebase/mmpose/deploy/pose_detection.py index 587bc7ecc..f8b1189ec 100644 --- a/mmdeploy/codebase/mmpose/deploy/pose_detection.py +++ b/mmdeploy/codebase/mmpose/deploy/pose_detection.py @@ -305,7 +305,7 @@ class PoseDetection(BaseTask): for k, v in sorted(results.items()): logger.info(f'{k}: {v:.4f}') - def get_model_name(self) -> str: + def get_model_name(self, *args, **kwargs) -> str: """Get the model name. Return: @@ -324,7 +324,7 @@ class PoseDetection(BaseTask): """ raise NotImplementedError('Not supported yet.') - def get_preprocess(self) -> Dict: + def get_preprocess(self, *args, **kwargs) -> Dict: """Get the preprocess information for SDK. Return: @@ -335,7 +335,7 @@ class PoseDetection(BaseTask): preprocess = model_cfg.data.test.pipeline return preprocess - def get_postprocess(self) -> Dict: + def get_postprocess(self, *args, **kwargs) -> Dict: """Get the postprocess information for SDK.""" postprocess = {'type': 'UNKNOWN'} if self.model_cfg.model.type == 'TopDown': diff --git a/mmdeploy/codebase/mmrotate/deploy/rotated_detection.py b/mmdeploy/codebase/mmrotate/deploy/rotated_detection.py index db9e3edae..fe32d3e7f 100644 --- a/mmdeploy/codebase/mmrotate/deploy/rotated_detection.py +++ b/mmdeploy/codebase/mmrotate/deploy/rotated_detection.py @@ -335,7 +335,7 @@ class RotatedDetection(BaseTask): eval_kwargs.update(dict(metric=metrics, **kwargs)) logger.info(dataset.evaluate(outputs, **eval_kwargs)) - def get_preprocess(self) -> Dict: + def get_preprocess(self, *args, **kwargs) -> Dict: """Get the preprocess information for SDK. Return: @@ -349,7 +349,7 @@ class RotatedDetection(BaseTask): preprocess = model_cfg.data.test.pipeline return preprocess - def get_postprocess(self) -> Dict: + def get_postprocess(self, *args, **kwargs) -> Dict: """Get the postprocess information for SDK. Return: @@ -358,7 +358,7 @@ class RotatedDetection(BaseTask): postprocess = self.model_cfg.model.test_cfg return postprocess - def get_model_name(self) -> str: + def get_model_name(self, *args, **kwargs) -> str: """Get the model name. Return: diff --git a/mmdeploy/codebase/mmseg/deploy/segmentation.py b/mmdeploy/codebase/mmseg/deploy/segmentation.py index 678d23fac..6c40f4d06 100644 --- a/mmdeploy/codebase/mmseg/deploy/segmentation.py +++ b/mmdeploy/codebase/mmseg/deploy/segmentation.py @@ -241,7 +241,7 @@ class Segmentation(BaseTask): def get_partition_cfg(partition_type: str) -> Dict: raise NotImplementedError('Not supported yet.') - def get_preprocess(self) -> Dict: + def get_preprocess(self, *args, **kwargs) -> Dict: """Get the preprocess information for SDK. Return: @@ -272,7 +272,7 @@ class Segmentation(BaseTask): ])) return preprocess - def get_postprocess(self) -> Dict: + def get_postprocess(self, *args, **kwargs) -> Dict: """Get the postprocess information for SDK. Return: @@ -281,7 +281,7 @@ class Segmentation(BaseTask): postprocess = self.model_cfg.model.decode_head return postprocess - def get_model_name(self) -> str: + def get_model_name(self, *args, **kwargs) -> str: """Get the model name. Return: diff --git a/mmdeploy/utils/__init__.py b/mmdeploy/utils/__init__.py index dda45c09d..0fb74fb22 100644 --- a/mmdeploy/utils/__init__.py +++ b/mmdeploy/utils/__init__.py @@ -22,7 +22,7 @@ if importlib.util.find_spec('mmcv') is not None: get_dynamic_axes, get_input_shape, get_ir_config, get_model_inputs, get_onnx_config, get_partition_config, - get_task_type, is_dynamic_batch, + get_precision, get_task_type, is_dynamic_batch, is_dynamic_shape, load_config) # yapf: enable @@ -33,5 +33,5 @@ if importlib.util.find_spec('mmcv') is not None: 'get_codebase_config', 'get_common_config', 'get_dynamic_axes', 'get_input_shape', 'get_ir_config', 'get_model_inputs', 'get_onnx_config', 'get_partition_config', 'get_task_type', - 'is_dynamic_batch', 'is_dynamic_shape', 'load_config' + 'is_dynamic_batch', 'is_dynamic_shape', 'load_config', 'get_precision' ] diff --git a/mmdeploy/utils/config_utils.py b/mmdeploy/utils/config_utils.py index b57228076..510e22f08 100644 --- a/mmdeploy/utils/config_utils.py +++ b/mmdeploy/utils/config_utils.py @@ -380,3 +380,26 @@ def get_dynamic_axes( raise KeyError('No names were found to define dynamic axes.') dynamic_axes = dict(zip(axes_names, dynamic_axes)) return dynamic_axes + + +def get_precision(deploy_cfg: Union[str, mmengine.Config]) -> str: + """Get precision of config. + + Args: + deploy_cfg (str | mmengine.Config): The path or content of config. + + Returns: + str: The precision of target backend. + """ + precision = 'FP32' + deploy_cfg = load_config(deploy_cfg)[0] + backend = get_backend(deploy_cfg=deploy_cfg) + if backend == Backend.TENSORRT: + common_cfg = get_common_config(deploy_cfg) + if common_cfg.get('fp16_mode', False): + precision = 'FP16' + if common_cfg.get('int8_mode', False): + precision = 'INT8' + if backend == Backend.NCNN and 'precision' in deploy_cfg['backend_config']: + precision = deploy_cfg['backend_config']['precision'] + return precision