refactor dump-info (#1040)
parent
97e0d1228f
commit
633d74fb36
|
@ -6,9 +6,9 @@ from typing import Dict, List, Tuple, Union
|
|||
import mmengine
|
||||
|
||||
from mmdeploy.apis import build_task_processor
|
||||
from mmdeploy.utils import (Backend, Task, get_backend, get_codebase,
|
||||
get_common_config, get_ir_config, get_root_logger,
|
||||
get_task_type, is_dynamic_batch, load_config)
|
||||
from mmdeploy.utils import (Task, get_backend, get_codebase, get_ir_config,
|
||||
get_precision, get_root_logger, get_task_type,
|
||||
is_dynamic_batch, load_config)
|
||||
from mmdeploy.utils.constants import SDK_TASK_MAP as task_map
|
||||
|
||||
|
||||
|
@ -79,60 +79,31 @@ def get_models(deploy_cfg: Union[str, mmengine.Config],
|
|||
name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir)
|
||||
precision = 'FP32'
|
||||
ir_name = get_ir_config(deploy_cfg)['save_file']
|
||||
net = ir_name
|
||||
weights = ''
|
||||
backend = get_backend(deploy_cfg=deploy_cfg)
|
||||
backend = get_backend(deploy_cfg=deploy_cfg).value
|
||||
|
||||
def replace_suffix(file_name: str, dst_suffix: str) -> str:
|
||||
"""Replace the suffix to the destination one.
|
||||
|
||||
Args:
|
||||
file_name (str): The file name to be operated.
|
||||
dst_suffix (str): The destination suffix.
|
||||
|
||||
Return:
|
||||
str: The file name of which the suffix has been replaced.
|
||||
"""
|
||||
return re.sub(r'\.[a-z]+', dst_suffix, file_name)
|
||||
|
||||
if backend == Backend.TENSORRT:
|
||||
net = replace_suffix(ir_name, '.engine')
|
||||
common_cfg = get_common_config(deploy_cfg)
|
||||
fp16_mode = common_cfg.get('fp16_mode', False)
|
||||
int8_mode = common_cfg.get('int8_mode', False)
|
||||
if fp16_mode:
|
||||
precision = 'FP16'
|
||||
if int8_mode:
|
||||
precision = 'INT8'
|
||||
elif backend == Backend.PPLNN:
|
||||
precision = 'FP16'
|
||||
weights = replace_suffix(ir_name, '.json')
|
||||
net = ir_name
|
||||
elif backend == Backend.OPENVINO:
|
||||
net = replace_suffix(ir_name, '.xml')
|
||||
weights = replace_suffix(ir_name, '.bin')
|
||||
elif backend == Backend.NCNN:
|
||||
net = replace_suffix(ir_name, '.param')
|
||||
weights = replace_suffix(ir_name, '.bin')
|
||||
if 'precision' in deploy_cfg['backend_config']:
|
||||
precision = deploy_cfg['backend_config']['precision']
|
||||
elif backend == Backend.SNPE:
|
||||
net = replace_suffix(ir_name, '.dlc')
|
||||
elif backend in [Backend.ONNXRUNTIME, Backend.TORCHSCRIPT]:
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError(f'Not supported backend: {backend.value}.')
|
||||
backend_net = dict(
|
||||
tensorrt=lambda file: re.sub(r'\.[a-z]+', '.engine', file),
|
||||
openvino=lambda file: re.sub(r'\.[a-z]+', '.xml', file),
|
||||
ncnn=lambda file: re.sub(r'\.[a-z]+', '.param', file),
|
||||
snpe=lambda file: re.sub(r'\.[a-z]+', '.dlc', file))
|
||||
backend_weights = dict(
|
||||
pplnn=lambda file: re.sub(r'\.[a-z]+', '.json', file),
|
||||
openvino=lambda file: re.sub(r'\.[a-z]+', '.bin', file),
|
||||
ncnn=lambda file: re.sub(r'\.[a-z]+', '.bin', file))
|
||||
net = backend_net.get(backend, lambda x: x)(ir_name)
|
||||
weights = backend_weights.get(backend, lambda x: weights)(ir_name)
|
||||
|
||||
precision = get_precision(deploy_cfg)
|
||||
dynamic_shape = is_dynamic_batch(deploy_cfg, input_name='input')
|
||||
batch_size = 1
|
||||
return [
|
||||
dict(
|
||||
name=name,
|
||||
net=net,
|
||||
weights=weights,
|
||||
backend=backend.value,
|
||||
backend=backend,
|
||||
precision=precision,
|
||||
batch_size=batch_size,
|
||||
batch_size=1,
|
||||
dynamic_shape=dynamic_shape)
|
||||
]
|
||||
|
||||
|
@ -151,20 +122,16 @@ def get_inference_info(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config,
|
|||
input_map.
|
||||
"""
|
||||
name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir)
|
||||
type = 'Task'
|
||||
module = 'Net'
|
||||
input = ['prep_output']
|
||||
output = ['infer_output']
|
||||
ir_config = get_ir_config(deploy_cfg)
|
||||
input_names = ir_config.get('input_names', None)
|
||||
input_name = input_names[0] if input_names else 'input'
|
||||
input_map = dict(img=input_name)
|
||||
return_dict = dict(
|
||||
name=name,
|
||||
type=type,
|
||||
module=module,
|
||||
input=input,
|
||||
output=output,
|
||||
type='Task',
|
||||
module='Net',
|
||||
input=['prep_output'],
|
||||
output=['infer_output'],
|
||||
input_map=input_map)
|
||||
if 'use_vulkan' in deploy_cfg['backend_config']:
|
||||
return_dict['use_vulkan'] = deploy_cfg['backend_config']['use_vulkan']
|
||||
|
@ -174,71 +141,15 @@ def get_inference_info(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config,
|
|||
def get_preprocess(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config):
|
||||
task_processor = build_task_processor(
|
||||
model_cfg=model_cfg, deploy_cfg=deploy_cfg, device='cpu')
|
||||
pipeline = task_processor.get_preprocess()
|
||||
type = 'Task'
|
||||
module = 'Transform'
|
||||
name = 'Preprocess'
|
||||
input = ['img']
|
||||
output = ['prep_output']
|
||||
meta_keys = [
|
||||
'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape',
|
||||
'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg', 'valid_ratio'
|
||||
]
|
||||
transforms = [
|
||||
item for item in pipeline
|
||||
if 'Random' not in item['type'] and 'RescaleToZeroOne' not in
|
||||
item['type'] and 'Annotation' not in item['type']
|
||||
]
|
||||
for i, transform in enumerate(transforms):
|
||||
if 'keys' in transform and transform['keys'] == ['lq']:
|
||||
transform['keys'] = ['img']
|
||||
if 'key' in transform and transform['key'] == 'lq':
|
||||
transform['key'] = 'img'
|
||||
if transform['type'] == 'ResizeEdge':
|
||||
transform['type'] = 'Resize'
|
||||
transform['keep_ratio'] = True
|
||||
# now the sdk of class has bugs, because ResizeEdge not implement
|
||||
# in sdk.
|
||||
transform['size'] = (transform['scale'], transform['scale'])
|
||||
if transform['type'] in ('PackTextDetInputs', 'PackTextRecogInputs'):
|
||||
meta_keys += transform[
|
||||
'meta_keys'] if 'meta_keys' in transform else []
|
||||
transform['meta_keys'] = list(set(meta_keys))
|
||||
transform['keys'] = ['img']
|
||||
transforms[i]['type'] = 'Collect'
|
||||
if transform['type'] == 'PackDetInputs' or \
|
||||
transform['type'] == 'PackClsInputs':
|
||||
transforms.insert(i, dict(type='DefaultFormatBundle'))
|
||||
transform['type'] = 'Collect'
|
||||
if 'keys' not in transform:
|
||||
transform['keys'] = ['img']
|
||||
if transform['type'] == 'Resize':
|
||||
transforms[i]['size'] = transforms[i]['scale']
|
||||
|
||||
data_preprocessor = model_cfg.model.data_preprocessor
|
||||
transforms.insert(-1, dict(type='DefaultFormatBundle'))
|
||||
transforms.insert(
|
||||
-2,
|
||||
dict(
|
||||
type='Pad',
|
||||
size_divisor=data_preprocessor.get('pad_size_divisor', 1)))
|
||||
transforms.insert(
|
||||
-3,
|
||||
dict(
|
||||
type='Normalize',
|
||||
to_rgb=data_preprocessor.get('bgr_to_rgb', False),
|
||||
mean=data_preprocessor.get('mean', [0, 0, 0]),
|
||||
std=data_preprocessor.get('std', [1, 1, 1])))
|
||||
|
||||
assert transforms[0]['type'] == 'LoadImageFromFile', 'The first item type'\
|
||||
' of pipeline should be LoadImageFromFile'
|
||||
|
||||
transforms = task_processor.get_preprocess()
|
||||
assert transforms[0]['type'] == 'LoadImageFromFile', 'The first item'\
|
||||
' type of pipeline should be LoadImageFromFile'
|
||||
return dict(
|
||||
type=type,
|
||||
module=module,
|
||||
name=name,
|
||||
input=input,
|
||||
output=output,
|
||||
type='Task',
|
||||
module='Transform',
|
||||
name='Preprocess',
|
||||
input=['img'],
|
||||
output=['prep_output'],
|
||||
transforms=transforms)
|
||||
|
||||
|
||||
|
@ -255,38 +166,17 @@ def get_postprocess(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config,
|
|||
dict: Composed of the model name, type, module, input, params and
|
||||
output.
|
||||
"""
|
||||
module = get_codebase(deploy_cfg).value
|
||||
type = 'Task'
|
||||
name = 'postprocess'
|
||||
params = dict()
|
||||
task = get_task_type(deploy_cfg)
|
||||
task_processor = build_task_processor(
|
||||
model_cfg=model_cfg, deploy_cfg=deploy_cfg, device='cpu')
|
||||
post_processor = task_processor.get_postprocess()
|
||||
post_processor = task_processor.get_postprocess(work_dir)
|
||||
|
||||
# TODO remove after adding instance segmentation to task processor
|
||||
if task == Task.OBJECT_DETECTION and 'mask_thr_binary' in post_processor:
|
||||
task = Task.INSTANCE_SEGMENTATION
|
||||
|
||||
component = task_map[task]['component']
|
||||
if task not in (Task.SUPER_RESOLUTION, Task.SEGMENTATION):
|
||||
if 'type' in post_processor:
|
||||
component = post_processor.pop('type')
|
||||
output = ['post_output']
|
||||
|
||||
if task == Task.TEXT_RECOGNITION:
|
||||
import shutil
|
||||
shutil.copy(model_cfg.dictionary.dict_file,
|
||||
f'{work_dir}/dict_file.txt')
|
||||
with_padding = model_cfg.dictionary.get('with_padding', False)
|
||||
params = dict(dict_file='dict_file.txt', with_padding=with_padding)
|
||||
return dict(
|
||||
type=type,
|
||||
module=module,
|
||||
name=name,
|
||||
component=component,
|
||||
params=params,
|
||||
output=output)
|
||||
type='Task',
|
||||
module=get_codebase(deploy_cfg).value,
|
||||
name='postprocess',
|
||||
component=post_processor['type'],
|
||||
params=post_processor.get('params', dict()),
|
||||
output=['post_output'])
|
||||
|
||||
|
||||
def get_deploy(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config,
|
||||
|
@ -384,7 +274,6 @@ def export2SDK(deploy_cfg: Union[str, mmengine.Config],
|
|||
pth (str): The path of the model checkpoint weights.
|
||||
"""
|
||||
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
|
||||
print(f'debugging what is type(model_cfg): {type(model_cfg)}')
|
||||
deploy_info = get_deploy(deploy_cfg, model_cfg, work_dir=work_dir)
|
||||
pipeline_info = get_pipeline(deploy_cfg, model_cfg, work_dir=work_dir)
|
||||
detail_info = get_detail(deploy_cfg, model_cfg, pth=pth)
|
||||
|
|
|
@ -322,7 +322,7 @@ class BaseTask(metaclass=ABCMeta):
|
|||
return input_data['inputs']
|
||||
|
||||
@abstractmethod
|
||||
def get_preprocess(self) -> Dict:
|
||||
def get_preprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the preprocess information for SDK.
|
||||
|
||||
Return:
|
||||
|
@ -331,7 +331,7 @@ class BaseTask(metaclass=ABCMeta):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_postprocess(self) -> Dict:
|
||||
def get_postprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the postprocess information for SDK.
|
||||
|
||||
Return:
|
||||
|
@ -340,7 +340,7 @@ class BaseTask(metaclass=ABCMeta):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_name(self) -> str:
|
||||
def get_model_name(self, *args, **kwargs) -> str:
|
||||
"""Get the model name.
|
||||
|
||||
Return:
|
||||
|
|
|
@ -241,7 +241,7 @@ class Classification(BaseTask):
|
|||
"""
|
||||
raise NotImplementedError('Not supported yet.')
|
||||
|
||||
def get_preprocess(self) -> Dict:
|
||||
def get_preprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the preprocess information for SDK.
|
||||
|
||||
Return:
|
||||
|
@ -252,7 +252,7 @@ class Classification(BaseTask):
|
|||
preprocess = cfg.test_pipeline
|
||||
return preprocess
|
||||
|
||||
def get_postprocess(self) -> Dict:
|
||||
def get_postprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the postprocess information for SDK.
|
||||
|
||||
Return:
|
||||
|
@ -269,7 +269,7 @@ class Classification(BaseTask):
|
|||
postprocess.topk = max(topk)
|
||||
return postprocess
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
def get_model_name(self, *args, **kwargs) -> str:
|
||||
"""Get the model name.
|
||||
|
||||
Return:
|
||||
|
|
|
@ -206,7 +206,7 @@ class ObjectDetection(BaseTask):
|
|||
f'Unknown partition_type {partition_type}'
|
||||
return MMDET_PARTITION_CFG[partition_type]
|
||||
|
||||
def get_preprocess(self) -> Dict:
|
||||
def get_preprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the preprocess information for SDK.
|
||||
|
||||
Return:
|
||||
|
@ -217,7 +217,7 @@ class ObjectDetection(BaseTask):
|
|||
preprocess = model_cfg.test_pipeline
|
||||
return preprocess
|
||||
|
||||
def get_postprocess(self) -> Dict:
|
||||
def get_postprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the postprocess information for SDK.
|
||||
|
||||
Return:
|
||||
|
@ -233,7 +233,7 @@ class ObjectDetection(BaseTask):
|
|||
'mask_thr_binary']
|
||||
return postprocess
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
def get_model_name(self, *args, **kwargs) -> str:
|
||||
"""Get the model name.
|
||||
|
||||
Return:
|
||||
|
|
|
@ -199,7 +199,7 @@ class VoxelDetection(BaseTask):
|
|||
eval_kwargs.update(dict(metric=metrics, **kwargs))
|
||||
dataset.evaluate(outputs, **eval_kwargs)
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
def get_model_name(self, *args, **kwargs) -> str:
|
||||
"""Get the model name.
|
||||
|
||||
Return:
|
||||
|
@ -230,7 +230,7 @@ class VoxelDetection(BaseTask):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_postprocess(self) -> Dict:
|
||||
def get_postprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the postprocess information for SDK.
|
||||
|
||||
Return:
|
||||
|
@ -238,7 +238,7 @@ class VoxelDetection(BaseTask):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_preprocess(self) -> Dict:
|
||||
def get_preprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the preprocess information for SDK.
|
||||
|
||||
Return:
|
||||
|
|
|
@ -286,7 +286,7 @@ class SuperResolution(BaseTask):
|
|||
for stat in stats:
|
||||
logger.info('Eval-{}: {}'.format(stat, stats[stat]))
|
||||
|
||||
def get_preprocess(self) -> Dict:
|
||||
def get_preprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the preprocess information for SDK.
|
||||
|
||||
Return:
|
||||
|
@ -300,7 +300,7 @@ class SuperResolution(BaseTask):
|
|||
item['std'] = [255, 255, 255]
|
||||
return preprocess
|
||||
|
||||
def get_postprocess(self) -> Dict:
|
||||
def get_postprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the postprocess information for SDK.
|
||||
|
||||
Return:
|
||||
|
@ -308,7 +308,7 @@ class SuperResolution(BaseTask):
|
|||
"""
|
||||
return dict()
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
def get_model_name(self, *args, **kwargs) -> str:
|
||||
"""Get the model name.
|
||||
|
||||
Return:
|
||||
|
|
|
@ -212,7 +212,7 @@ class TextDetection(BaseTask):
|
|||
"""
|
||||
raise NotImplementedError('Not supported yet.')
|
||||
|
||||
def get_preprocess(self) -> Dict:
|
||||
def get_preprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the preprocess information for SDK.
|
||||
|
||||
Return:
|
||||
|
@ -220,10 +220,43 @@ class TextDetection(BaseTask):
|
|||
"""
|
||||
input_shape = get_input_shape(self.deploy_cfg)
|
||||
model_cfg = process_model_config(self.model_cfg, [''], input_shape)
|
||||
preprocess = model_cfg.test_dataloader.dataset.pipeline
|
||||
return preprocess
|
||||
pipeline = model_cfg.test_dataloader.dataset.pipeline
|
||||
meta_keys = [
|
||||
'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape',
|
||||
'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg',
|
||||
'valid_ratio'
|
||||
]
|
||||
transforms = [
|
||||
item for item in pipeline if 'Random' not in item['type']
|
||||
and 'Annotation' not in item['type']
|
||||
]
|
||||
for i, transform in enumerate(transforms):
|
||||
if transform['type'] == 'PackTextDetInputs':
|
||||
meta_keys += transform[
|
||||
'meta_keys'] if 'meta_keys' in transform else []
|
||||
transform['meta_keys'] = list(set(meta_keys))
|
||||
transform['keys'] = ['img']
|
||||
transforms[i]['type'] = 'Collect'
|
||||
if transform['type'] == 'Resize':
|
||||
transforms[i]['size'] = transforms[i]['scale']
|
||||
|
||||
def get_postprocess(self) -> Dict:
|
||||
data_preprocessor = model_cfg.model.data_preprocessor
|
||||
transforms.insert(-1, dict(type='DefaultFormatBundle'))
|
||||
transforms.insert(
|
||||
-2,
|
||||
dict(
|
||||
type='Pad',
|
||||
size_divisor=data_preprocessor.get('pad_size_divisor', 1)))
|
||||
transforms.insert(
|
||||
-3,
|
||||
dict(
|
||||
type='Normalize',
|
||||
to_rgb=data_preprocessor.get('bgr_to_rgb', False),
|
||||
mean=data_preprocessor.get('mean', [0, 0, 0]),
|
||||
std=data_preprocessor.get('std', [1, 1, 1])))
|
||||
return transforms
|
||||
|
||||
def get_postprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the postprocess information for SDK.
|
||||
|
||||
Return:
|
||||
|
@ -232,7 +265,7 @@ class TextDetection(BaseTask):
|
|||
postprocess = self.model_cfg.model.det_head
|
||||
return postprocess
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
def get_model_name(self, *args, **kwargs) -> str:
|
||||
"""Get the model name.
|
||||
|
||||
Return:
|
||||
|
|
|
@ -218,7 +218,7 @@ class TextRecognition(BaseTask):
|
|||
"""
|
||||
raise NotImplementedError('Not supported yet.')
|
||||
|
||||
def get_preprocess(self) -> Dict:
|
||||
def get_preprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the preprocess information for SDK.
|
||||
|
||||
Return:
|
||||
|
@ -226,21 +226,62 @@ class TextRecognition(BaseTask):
|
|||
"""
|
||||
input_shape = get_input_shape(self.deploy_cfg)
|
||||
model_cfg = process_model_config(self.model_cfg, [''], input_shape)
|
||||
preprocess = model_cfg.test_dataloader.dataset.pipeline
|
||||
return preprocess
|
||||
pipeline = model_cfg.test_dataloader.dataset.pipeline
|
||||
meta_keys = [
|
||||
'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape',
|
||||
'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg',
|
||||
'valid_ratio'
|
||||
]
|
||||
transforms = [
|
||||
item for item in pipeline if 'Random' not in item['type']
|
||||
and 'Annotation' not in item['type']
|
||||
]
|
||||
for i, transform in enumerate(transforms):
|
||||
if transform['type'] == 'PackTextRecogInputs':
|
||||
meta_keys += transform[
|
||||
'meta_keys'] if 'meta_keys' in transform else []
|
||||
transform['meta_keys'] = list(set(meta_keys))
|
||||
transform['keys'] = ['img']
|
||||
transforms[i]['type'] = 'Collect'
|
||||
if transform['type'] == 'Resize':
|
||||
transforms[i]['size'] = transforms[i]['scale']
|
||||
|
||||
def get_postprocess(self) -> Dict:
|
||||
data_preprocessor = model_cfg.model.data_preprocessor
|
||||
transforms.insert(-1, dict(type='DefaultFormatBundle'))
|
||||
transforms.insert(
|
||||
-2,
|
||||
dict(
|
||||
type='Pad',
|
||||
size_divisor=data_preprocessor.get('pad_size_divisor', 1)))
|
||||
transforms.insert(
|
||||
-3,
|
||||
dict(
|
||||
type='Normalize',
|
||||
to_rgb=data_preprocessor.get('bgr_to_rgb', False),
|
||||
mean=data_preprocessor.get('mean', [0, 0, 0]),
|
||||
std=data_preprocessor.get('std', [1, 1, 1])))
|
||||
return transforms
|
||||
|
||||
def get_postprocess(self,
|
||||
work_dir: Optional[str] = None,
|
||||
**kwargs) -> Dict:
|
||||
"""Get the postprocess information for SDK.
|
||||
|
||||
Return:
|
||||
dict: Composed of the postprocess information.
|
||||
Dict: Composed of the postprocess information.
|
||||
"""
|
||||
postprocess = self.model_cfg.model.decoder.postprocessor
|
||||
if postprocess.type == 'CTCPostProcessor':
|
||||
postprocess.type = 'CTCConvertor'
|
||||
import shutil
|
||||
shutil.copy(self.model_cfg.dictionary.dict_file,
|
||||
f'{work_dir}/dict_file.txt')
|
||||
with_padding = self.model_cfg.dictionary.get('with_padding', False)
|
||||
params = dict(dict_file='dict_file.txt', with_padding=with_padding)
|
||||
postprocess['params'] = params
|
||||
return postprocess
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
def get_model_name(self, *args, **kwargs) -> str:
|
||||
"""Get the model name.
|
||||
|
||||
Return:
|
||||
|
|
|
@ -305,7 +305,7 @@ class PoseDetection(BaseTask):
|
|||
for k, v in sorted(results.items()):
|
||||
logger.info(f'{k}: {v:.4f}')
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
def get_model_name(self, *args, **kwargs) -> str:
|
||||
"""Get the model name.
|
||||
|
||||
Return:
|
||||
|
@ -324,7 +324,7 @@ class PoseDetection(BaseTask):
|
|||
"""
|
||||
raise NotImplementedError('Not supported yet.')
|
||||
|
||||
def get_preprocess(self) -> Dict:
|
||||
def get_preprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the preprocess information for SDK.
|
||||
|
||||
Return:
|
||||
|
@ -335,7 +335,7 @@ class PoseDetection(BaseTask):
|
|||
preprocess = model_cfg.data.test.pipeline
|
||||
return preprocess
|
||||
|
||||
def get_postprocess(self) -> Dict:
|
||||
def get_postprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the postprocess information for SDK."""
|
||||
postprocess = {'type': 'UNKNOWN'}
|
||||
if self.model_cfg.model.type == 'TopDown':
|
||||
|
|
|
@ -335,7 +335,7 @@ class RotatedDetection(BaseTask):
|
|||
eval_kwargs.update(dict(metric=metrics, **kwargs))
|
||||
logger.info(dataset.evaluate(outputs, **eval_kwargs))
|
||||
|
||||
def get_preprocess(self) -> Dict:
|
||||
def get_preprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the preprocess information for SDK.
|
||||
|
||||
Return:
|
||||
|
@ -349,7 +349,7 @@ class RotatedDetection(BaseTask):
|
|||
preprocess = model_cfg.data.test.pipeline
|
||||
return preprocess
|
||||
|
||||
def get_postprocess(self) -> Dict:
|
||||
def get_postprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the postprocess information for SDK.
|
||||
|
||||
Return:
|
||||
|
@ -358,7 +358,7 @@ class RotatedDetection(BaseTask):
|
|||
postprocess = self.model_cfg.model.test_cfg
|
||||
return postprocess
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
def get_model_name(self, *args, **kwargs) -> str:
|
||||
"""Get the model name.
|
||||
|
||||
Return:
|
||||
|
|
|
@ -241,7 +241,7 @@ class Segmentation(BaseTask):
|
|||
def get_partition_cfg(partition_type: str) -> Dict:
|
||||
raise NotImplementedError('Not supported yet.')
|
||||
|
||||
def get_preprocess(self) -> Dict:
|
||||
def get_preprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the preprocess information for SDK.
|
||||
|
||||
Return:
|
||||
|
@ -272,7 +272,7 @@ class Segmentation(BaseTask):
|
|||
]))
|
||||
return preprocess
|
||||
|
||||
def get_postprocess(self) -> Dict:
|
||||
def get_postprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the postprocess information for SDK.
|
||||
|
||||
Return:
|
||||
|
@ -281,7 +281,7 @@ class Segmentation(BaseTask):
|
|||
postprocess = self.model_cfg.model.decode_head
|
||||
return postprocess
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
def get_model_name(self, *args, **kwargs) -> str:
|
||||
"""Get the model name.
|
||||
|
||||
Return:
|
||||
|
|
|
@ -22,7 +22,7 @@ if importlib.util.find_spec('mmcv') is not None:
|
|||
get_dynamic_axes, get_input_shape,
|
||||
get_ir_config, get_model_inputs,
|
||||
get_onnx_config, get_partition_config,
|
||||
get_task_type, is_dynamic_batch,
|
||||
get_precision, get_task_type, is_dynamic_batch,
|
||||
is_dynamic_shape, load_config)
|
||||
|
||||
# yapf: enable
|
||||
|
@ -33,5 +33,5 @@ if importlib.util.find_spec('mmcv') is not None:
|
|||
'get_codebase_config', 'get_common_config', 'get_dynamic_axes',
|
||||
'get_input_shape', 'get_ir_config', 'get_model_inputs',
|
||||
'get_onnx_config', 'get_partition_config', 'get_task_type',
|
||||
'is_dynamic_batch', 'is_dynamic_shape', 'load_config'
|
||||
'is_dynamic_batch', 'is_dynamic_shape', 'load_config', 'get_precision'
|
||||
]
|
||||
|
|
|
@ -380,3 +380,26 @@ def get_dynamic_axes(
|
|||
raise KeyError('No names were found to define dynamic axes.')
|
||||
dynamic_axes = dict(zip(axes_names, dynamic_axes))
|
||||
return dynamic_axes
|
||||
|
||||
|
||||
def get_precision(deploy_cfg: Union[str, mmengine.Config]) -> str:
|
||||
"""Get precision of config.
|
||||
|
||||
Args:
|
||||
deploy_cfg (str | mmengine.Config): The path or content of config.
|
||||
|
||||
Returns:
|
||||
str: The precision of target backend.
|
||||
"""
|
||||
precision = 'FP32'
|
||||
deploy_cfg = load_config(deploy_cfg)[0]
|
||||
backend = get_backend(deploy_cfg=deploy_cfg)
|
||||
if backend == Backend.TENSORRT:
|
||||
common_cfg = get_common_config(deploy_cfg)
|
||||
if common_cfg.get('fp16_mode', False):
|
||||
precision = 'FP16'
|
||||
if common_cfg.get('int8_mode', False):
|
||||
precision = 'INT8'
|
||||
if backend == Backend.NCNN and 'precision' in deploy_cfg['backend_config']:
|
||||
precision = deploy_cfg['backend_config']['precision']
|
||||
return precision
|
||||
|
|
Loading…
Reference in New Issue