refactor dump-info (#1040)

pull/1070/head^2
AllentDan 2022-09-19 17:15:21 +08:00 committed by GitHub
parent 97e0d1228f
commit 633d74fb36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 171 additions and 185 deletions

View File

@ -6,9 +6,9 @@ from typing import Dict, List, Tuple, Union
import mmengine
from mmdeploy.apis import build_task_processor
from mmdeploy.utils import (Backend, Task, get_backend, get_codebase,
get_common_config, get_ir_config, get_root_logger,
get_task_type, is_dynamic_batch, load_config)
from mmdeploy.utils import (Task, get_backend, get_codebase, get_ir_config,
get_precision, get_root_logger, get_task_type,
is_dynamic_batch, load_config)
from mmdeploy.utils.constants import SDK_TASK_MAP as task_map
@ -79,60 +79,31 @@ def get_models(deploy_cfg: Union[str, mmengine.Config],
name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir)
precision = 'FP32'
ir_name = get_ir_config(deploy_cfg)['save_file']
net = ir_name
weights = ''
backend = get_backend(deploy_cfg=deploy_cfg)
backend = get_backend(deploy_cfg=deploy_cfg).value
def replace_suffix(file_name: str, dst_suffix: str) -> str:
"""Replace the suffix to the destination one.
Args:
file_name (str): The file name to be operated.
dst_suffix (str): The destination suffix.
Return:
str: The file name of which the suffix has been replaced.
"""
return re.sub(r'\.[a-z]+', dst_suffix, file_name)
if backend == Backend.TENSORRT:
net = replace_suffix(ir_name, '.engine')
common_cfg = get_common_config(deploy_cfg)
fp16_mode = common_cfg.get('fp16_mode', False)
int8_mode = common_cfg.get('int8_mode', False)
if fp16_mode:
precision = 'FP16'
if int8_mode:
precision = 'INT8'
elif backend == Backend.PPLNN:
precision = 'FP16'
weights = replace_suffix(ir_name, '.json')
net = ir_name
elif backend == Backend.OPENVINO:
net = replace_suffix(ir_name, '.xml')
weights = replace_suffix(ir_name, '.bin')
elif backend == Backend.NCNN:
net = replace_suffix(ir_name, '.param')
weights = replace_suffix(ir_name, '.bin')
if 'precision' in deploy_cfg['backend_config']:
precision = deploy_cfg['backend_config']['precision']
elif backend == Backend.SNPE:
net = replace_suffix(ir_name, '.dlc')
elif backend in [Backend.ONNXRUNTIME, Backend.TORCHSCRIPT]:
pass
else:
raise NotImplementedError(f'Not supported backend: {backend.value}.')
backend_net = dict(
tensorrt=lambda file: re.sub(r'\.[a-z]+', '.engine', file),
openvino=lambda file: re.sub(r'\.[a-z]+', '.xml', file),
ncnn=lambda file: re.sub(r'\.[a-z]+', '.param', file),
snpe=lambda file: re.sub(r'\.[a-z]+', '.dlc', file))
backend_weights = dict(
pplnn=lambda file: re.sub(r'\.[a-z]+', '.json', file),
openvino=lambda file: re.sub(r'\.[a-z]+', '.bin', file),
ncnn=lambda file: re.sub(r'\.[a-z]+', '.bin', file))
net = backend_net.get(backend, lambda x: x)(ir_name)
weights = backend_weights.get(backend, lambda x: weights)(ir_name)
precision = get_precision(deploy_cfg)
dynamic_shape = is_dynamic_batch(deploy_cfg, input_name='input')
batch_size = 1
return [
dict(
name=name,
net=net,
weights=weights,
backend=backend.value,
backend=backend,
precision=precision,
batch_size=batch_size,
batch_size=1,
dynamic_shape=dynamic_shape)
]
@ -151,20 +122,16 @@ def get_inference_info(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config,
input_map.
"""
name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir)
type = 'Task'
module = 'Net'
input = ['prep_output']
output = ['infer_output']
ir_config = get_ir_config(deploy_cfg)
input_names = ir_config.get('input_names', None)
input_name = input_names[0] if input_names else 'input'
input_map = dict(img=input_name)
return_dict = dict(
name=name,
type=type,
module=module,
input=input,
output=output,
type='Task',
module='Net',
input=['prep_output'],
output=['infer_output'],
input_map=input_map)
if 'use_vulkan' in deploy_cfg['backend_config']:
return_dict['use_vulkan'] = deploy_cfg['backend_config']['use_vulkan']
@ -174,71 +141,15 @@ def get_inference_info(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config,
def get_preprocess(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config):
task_processor = build_task_processor(
model_cfg=model_cfg, deploy_cfg=deploy_cfg, device='cpu')
pipeline = task_processor.get_preprocess()
type = 'Task'
module = 'Transform'
name = 'Preprocess'
input = ['img']
output = ['prep_output']
meta_keys = [
'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape',
'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg', 'valid_ratio'
]
transforms = [
item for item in pipeline
if 'Random' not in item['type'] and 'RescaleToZeroOne' not in
item['type'] and 'Annotation' not in item['type']
]
for i, transform in enumerate(transforms):
if 'keys' in transform and transform['keys'] == ['lq']:
transform['keys'] = ['img']
if 'key' in transform and transform['key'] == 'lq':
transform['key'] = 'img'
if transform['type'] == 'ResizeEdge':
transform['type'] = 'Resize'
transform['keep_ratio'] = True
# now the sdk of class has bugs, because ResizeEdge not implement
# in sdk.
transform['size'] = (transform['scale'], transform['scale'])
if transform['type'] in ('PackTextDetInputs', 'PackTextRecogInputs'):
meta_keys += transform[
'meta_keys'] if 'meta_keys' in transform else []
transform['meta_keys'] = list(set(meta_keys))
transform['keys'] = ['img']
transforms[i]['type'] = 'Collect'
if transform['type'] == 'PackDetInputs' or \
transform['type'] == 'PackClsInputs':
transforms.insert(i, dict(type='DefaultFormatBundle'))
transform['type'] = 'Collect'
if 'keys' not in transform:
transform['keys'] = ['img']
if transform['type'] == 'Resize':
transforms[i]['size'] = transforms[i]['scale']
data_preprocessor = model_cfg.model.data_preprocessor
transforms.insert(-1, dict(type='DefaultFormatBundle'))
transforms.insert(
-2,
dict(
type='Pad',
size_divisor=data_preprocessor.get('pad_size_divisor', 1)))
transforms.insert(
-3,
dict(
type='Normalize',
to_rgb=data_preprocessor.get('bgr_to_rgb', False),
mean=data_preprocessor.get('mean', [0, 0, 0]),
std=data_preprocessor.get('std', [1, 1, 1])))
assert transforms[0]['type'] == 'LoadImageFromFile', 'The first item type'\
' of pipeline should be LoadImageFromFile'
transforms = task_processor.get_preprocess()
assert transforms[0]['type'] == 'LoadImageFromFile', 'The first item'\
' type of pipeline should be LoadImageFromFile'
return dict(
type=type,
module=module,
name=name,
input=input,
output=output,
type='Task',
module='Transform',
name='Preprocess',
input=['img'],
output=['prep_output'],
transforms=transforms)
@ -255,38 +166,17 @@ def get_postprocess(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config,
dict: Composed of the model name, type, module, input, params and
output.
"""
module = get_codebase(deploy_cfg).value
type = 'Task'
name = 'postprocess'
params = dict()
task = get_task_type(deploy_cfg)
task_processor = build_task_processor(
model_cfg=model_cfg, deploy_cfg=deploy_cfg, device='cpu')
post_processor = task_processor.get_postprocess()
post_processor = task_processor.get_postprocess(work_dir)
# TODO remove after adding instance segmentation to task processor
if task == Task.OBJECT_DETECTION and 'mask_thr_binary' in post_processor:
task = Task.INSTANCE_SEGMENTATION
component = task_map[task]['component']
if task not in (Task.SUPER_RESOLUTION, Task.SEGMENTATION):
if 'type' in post_processor:
component = post_processor.pop('type')
output = ['post_output']
if task == Task.TEXT_RECOGNITION:
import shutil
shutil.copy(model_cfg.dictionary.dict_file,
f'{work_dir}/dict_file.txt')
with_padding = model_cfg.dictionary.get('with_padding', False)
params = dict(dict_file='dict_file.txt', with_padding=with_padding)
return dict(
type=type,
module=module,
name=name,
component=component,
params=params,
output=output)
type='Task',
module=get_codebase(deploy_cfg).value,
name='postprocess',
component=post_processor['type'],
params=post_processor.get('params', dict()),
output=['post_output'])
def get_deploy(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config,
@ -384,7 +274,6 @@ def export2SDK(deploy_cfg: Union[str, mmengine.Config],
pth (str): The path of the model checkpoint weights.
"""
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
print(f'debugging what is type(model_cfg): {type(model_cfg)}')
deploy_info = get_deploy(deploy_cfg, model_cfg, work_dir=work_dir)
pipeline_info = get_pipeline(deploy_cfg, model_cfg, work_dir=work_dir)
detail_info = get_detail(deploy_cfg, model_cfg, pth=pth)

View File

@ -322,7 +322,7 @@ class BaseTask(metaclass=ABCMeta):
return input_data['inputs']
@abstractmethod
def get_preprocess(self) -> Dict:
def get_preprocess(self, *args, **kwargs) -> Dict:
"""Get the preprocess information for SDK.
Return:
@ -331,7 +331,7 @@ class BaseTask(metaclass=ABCMeta):
pass
@abstractmethod
def get_postprocess(self) -> Dict:
def get_postprocess(self, *args, **kwargs) -> Dict:
"""Get the postprocess information for SDK.
Return:
@ -340,7 +340,7 @@ class BaseTask(metaclass=ABCMeta):
pass
@abstractmethod
def get_model_name(self) -> str:
def get_model_name(self, *args, **kwargs) -> str:
"""Get the model name.
Return:

View File

@ -241,7 +241,7 @@ class Classification(BaseTask):
"""
raise NotImplementedError('Not supported yet.')
def get_preprocess(self) -> Dict:
def get_preprocess(self, *args, **kwargs) -> Dict:
"""Get the preprocess information for SDK.
Return:
@ -252,7 +252,7 @@ class Classification(BaseTask):
preprocess = cfg.test_pipeline
return preprocess
def get_postprocess(self) -> Dict:
def get_postprocess(self, *args, **kwargs) -> Dict:
"""Get the postprocess information for SDK.
Return:
@ -269,7 +269,7 @@ class Classification(BaseTask):
postprocess.topk = max(topk)
return postprocess
def get_model_name(self) -> str:
def get_model_name(self, *args, **kwargs) -> str:
"""Get the model name.
Return:

View File

@ -206,7 +206,7 @@ class ObjectDetection(BaseTask):
f'Unknown partition_type {partition_type}'
return MMDET_PARTITION_CFG[partition_type]
def get_preprocess(self) -> Dict:
def get_preprocess(self, *args, **kwargs) -> Dict:
"""Get the preprocess information for SDK.
Return:
@ -217,7 +217,7 @@ class ObjectDetection(BaseTask):
preprocess = model_cfg.test_pipeline
return preprocess
def get_postprocess(self) -> Dict:
def get_postprocess(self, *args, **kwargs) -> Dict:
"""Get the postprocess information for SDK.
Return:
@ -233,7 +233,7 @@ class ObjectDetection(BaseTask):
'mask_thr_binary']
return postprocess
def get_model_name(self) -> str:
def get_model_name(self, *args, **kwargs) -> str:
"""Get the model name.
Return:

View File

@ -199,7 +199,7 @@ class VoxelDetection(BaseTask):
eval_kwargs.update(dict(metric=metrics, **kwargs))
dataset.evaluate(outputs, **eval_kwargs)
def get_model_name(self) -> str:
def get_model_name(self, *args, **kwargs) -> str:
"""Get the model name.
Return:
@ -230,7 +230,7 @@ class VoxelDetection(BaseTask):
"""
raise NotImplementedError
def get_postprocess(self) -> Dict:
def get_postprocess(self, *args, **kwargs) -> Dict:
"""Get the postprocess information for SDK.
Return:
@ -238,7 +238,7 @@ class VoxelDetection(BaseTask):
"""
raise NotImplementedError
def get_preprocess(self) -> Dict:
def get_preprocess(self, *args, **kwargs) -> Dict:
"""Get the preprocess information for SDK.
Return:

View File

@ -286,7 +286,7 @@ class SuperResolution(BaseTask):
for stat in stats:
logger.info('Eval-{}: {}'.format(stat, stats[stat]))
def get_preprocess(self) -> Dict:
def get_preprocess(self, *args, **kwargs) -> Dict:
"""Get the preprocess information for SDK.
Return:
@ -300,7 +300,7 @@ class SuperResolution(BaseTask):
item['std'] = [255, 255, 255]
return preprocess
def get_postprocess(self) -> Dict:
def get_postprocess(self, *args, **kwargs) -> Dict:
"""Get the postprocess information for SDK.
Return:
@ -308,7 +308,7 @@ class SuperResolution(BaseTask):
"""
return dict()
def get_model_name(self) -> str:
def get_model_name(self, *args, **kwargs) -> str:
"""Get the model name.
Return:

View File

@ -212,7 +212,7 @@ class TextDetection(BaseTask):
"""
raise NotImplementedError('Not supported yet.')
def get_preprocess(self) -> Dict:
def get_preprocess(self, *args, **kwargs) -> Dict:
"""Get the preprocess information for SDK.
Return:
@ -220,10 +220,43 @@ class TextDetection(BaseTask):
"""
input_shape = get_input_shape(self.deploy_cfg)
model_cfg = process_model_config(self.model_cfg, [''], input_shape)
preprocess = model_cfg.test_dataloader.dataset.pipeline
return preprocess
pipeline = model_cfg.test_dataloader.dataset.pipeline
meta_keys = [
'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape',
'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg',
'valid_ratio'
]
transforms = [
item for item in pipeline if 'Random' not in item['type']
and 'Annotation' not in item['type']
]
for i, transform in enumerate(transforms):
if transform['type'] == 'PackTextDetInputs':
meta_keys += transform[
'meta_keys'] if 'meta_keys' in transform else []
transform['meta_keys'] = list(set(meta_keys))
transform['keys'] = ['img']
transforms[i]['type'] = 'Collect'
if transform['type'] == 'Resize':
transforms[i]['size'] = transforms[i]['scale']
def get_postprocess(self) -> Dict:
data_preprocessor = model_cfg.model.data_preprocessor
transforms.insert(-1, dict(type='DefaultFormatBundle'))
transforms.insert(
-2,
dict(
type='Pad',
size_divisor=data_preprocessor.get('pad_size_divisor', 1)))
transforms.insert(
-3,
dict(
type='Normalize',
to_rgb=data_preprocessor.get('bgr_to_rgb', False),
mean=data_preprocessor.get('mean', [0, 0, 0]),
std=data_preprocessor.get('std', [1, 1, 1])))
return transforms
def get_postprocess(self, *args, **kwargs) -> Dict:
"""Get the postprocess information for SDK.
Return:
@ -232,7 +265,7 @@ class TextDetection(BaseTask):
postprocess = self.model_cfg.model.det_head
return postprocess
def get_model_name(self) -> str:
def get_model_name(self, *args, **kwargs) -> str:
"""Get the model name.
Return:

View File

@ -218,7 +218,7 @@ class TextRecognition(BaseTask):
"""
raise NotImplementedError('Not supported yet.')
def get_preprocess(self) -> Dict:
def get_preprocess(self, *args, **kwargs) -> Dict:
"""Get the preprocess information for SDK.
Return:
@ -226,21 +226,62 @@ class TextRecognition(BaseTask):
"""
input_shape = get_input_shape(self.deploy_cfg)
model_cfg = process_model_config(self.model_cfg, [''], input_shape)
preprocess = model_cfg.test_dataloader.dataset.pipeline
return preprocess
pipeline = model_cfg.test_dataloader.dataset.pipeline
meta_keys = [
'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape',
'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg',
'valid_ratio'
]
transforms = [
item for item in pipeline if 'Random' not in item['type']
and 'Annotation' not in item['type']
]
for i, transform in enumerate(transforms):
if transform['type'] == 'PackTextRecogInputs':
meta_keys += transform[
'meta_keys'] if 'meta_keys' in transform else []
transform['meta_keys'] = list(set(meta_keys))
transform['keys'] = ['img']
transforms[i]['type'] = 'Collect'
if transform['type'] == 'Resize':
transforms[i]['size'] = transforms[i]['scale']
def get_postprocess(self) -> Dict:
data_preprocessor = model_cfg.model.data_preprocessor
transforms.insert(-1, dict(type='DefaultFormatBundle'))
transforms.insert(
-2,
dict(
type='Pad',
size_divisor=data_preprocessor.get('pad_size_divisor', 1)))
transforms.insert(
-3,
dict(
type='Normalize',
to_rgb=data_preprocessor.get('bgr_to_rgb', False),
mean=data_preprocessor.get('mean', [0, 0, 0]),
std=data_preprocessor.get('std', [1, 1, 1])))
return transforms
def get_postprocess(self,
work_dir: Optional[str] = None,
**kwargs) -> Dict:
"""Get the postprocess information for SDK.
Return:
dict: Composed of the postprocess information.
Dict: Composed of the postprocess information.
"""
postprocess = self.model_cfg.model.decoder.postprocessor
if postprocess.type == 'CTCPostProcessor':
postprocess.type = 'CTCConvertor'
import shutil
shutil.copy(self.model_cfg.dictionary.dict_file,
f'{work_dir}/dict_file.txt')
with_padding = self.model_cfg.dictionary.get('with_padding', False)
params = dict(dict_file='dict_file.txt', with_padding=with_padding)
postprocess['params'] = params
return postprocess
def get_model_name(self) -> str:
def get_model_name(self, *args, **kwargs) -> str:
"""Get the model name.
Return:

View File

@ -305,7 +305,7 @@ class PoseDetection(BaseTask):
for k, v in sorted(results.items()):
logger.info(f'{k}: {v:.4f}')
def get_model_name(self) -> str:
def get_model_name(self, *args, **kwargs) -> str:
"""Get the model name.
Return:
@ -324,7 +324,7 @@ class PoseDetection(BaseTask):
"""
raise NotImplementedError('Not supported yet.')
def get_preprocess(self) -> Dict:
def get_preprocess(self, *args, **kwargs) -> Dict:
"""Get the preprocess information for SDK.
Return:
@ -335,7 +335,7 @@ class PoseDetection(BaseTask):
preprocess = model_cfg.data.test.pipeline
return preprocess
def get_postprocess(self) -> Dict:
def get_postprocess(self, *args, **kwargs) -> Dict:
"""Get the postprocess information for SDK."""
postprocess = {'type': 'UNKNOWN'}
if self.model_cfg.model.type == 'TopDown':

View File

@ -335,7 +335,7 @@ class RotatedDetection(BaseTask):
eval_kwargs.update(dict(metric=metrics, **kwargs))
logger.info(dataset.evaluate(outputs, **eval_kwargs))
def get_preprocess(self) -> Dict:
def get_preprocess(self, *args, **kwargs) -> Dict:
"""Get the preprocess information for SDK.
Return:
@ -349,7 +349,7 @@ class RotatedDetection(BaseTask):
preprocess = model_cfg.data.test.pipeline
return preprocess
def get_postprocess(self) -> Dict:
def get_postprocess(self, *args, **kwargs) -> Dict:
"""Get the postprocess information for SDK.
Return:
@ -358,7 +358,7 @@ class RotatedDetection(BaseTask):
postprocess = self.model_cfg.model.test_cfg
return postprocess
def get_model_name(self) -> str:
def get_model_name(self, *args, **kwargs) -> str:
"""Get the model name.
Return:

View File

@ -241,7 +241,7 @@ class Segmentation(BaseTask):
def get_partition_cfg(partition_type: str) -> Dict:
raise NotImplementedError('Not supported yet.')
def get_preprocess(self) -> Dict:
def get_preprocess(self, *args, **kwargs) -> Dict:
"""Get the preprocess information for SDK.
Return:
@ -272,7 +272,7 @@ class Segmentation(BaseTask):
]))
return preprocess
def get_postprocess(self) -> Dict:
def get_postprocess(self, *args, **kwargs) -> Dict:
"""Get the postprocess information for SDK.
Return:
@ -281,7 +281,7 @@ class Segmentation(BaseTask):
postprocess = self.model_cfg.model.decode_head
return postprocess
def get_model_name(self) -> str:
def get_model_name(self, *args, **kwargs) -> str:
"""Get the model name.
Return:

View File

@ -22,7 +22,7 @@ if importlib.util.find_spec('mmcv') is not None:
get_dynamic_axes, get_input_shape,
get_ir_config, get_model_inputs,
get_onnx_config, get_partition_config,
get_task_type, is_dynamic_batch,
get_precision, get_task_type, is_dynamic_batch,
is_dynamic_shape, load_config)
# yapf: enable
@ -33,5 +33,5 @@ if importlib.util.find_spec('mmcv') is not None:
'get_codebase_config', 'get_common_config', 'get_dynamic_axes',
'get_input_shape', 'get_ir_config', 'get_model_inputs',
'get_onnx_config', 'get_partition_config', 'get_task_type',
'is_dynamic_batch', 'is_dynamic_shape', 'load_config'
'is_dynamic_batch', 'is_dynamic_shape', 'load_config', 'get_precision'
]

View File

@ -380,3 +380,26 @@ def get_dynamic_axes(
raise KeyError('No names were found to define dynamic axes.')
dynamic_axes = dict(zip(axes_names, dynamic_axes))
return dynamic_axes
def get_precision(deploy_cfg: Union[str, mmengine.Config]) -> str:
"""Get precision of config.
Args:
deploy_cfg (str | mmengine.Config): The path or content of config.
Returns:
str: The precision of target backend.
"""
precision = 'FP32'
deploy_cfg = load_config(deploy_cfg)[0]
backend = get_backend(deploy_cfg=deploy_cfg)
if backend == Backend.TENSORRT:
common_cfg = get_common_config(deploy_cfg)
if common_cfg.get('fp16_mode', False):
precision = 'FP16'
if common_cfg.get('int8_mode', False):
precision = 'INT8'
if backend == Backend.NCNN and 'precision' in deploy_cfg['backend_config']:
precision = deploy_cfg['backend_config']['precision']
return precision