mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Feature] Align datasets (#29)
* add test tool and re-orgnize apis.utils * handle topk and refine codes * add cls export and test support * fix lint * move ort into wrapper * resolve conflicts * resolve comments * resolve conflicts * resolve comments and padding mrcnn * resolve comments
This commit is contained in:
parent
90ce7207da
commit
f607f1965b
@ -3,6 +3,6 @@ tensorrt_params = dict(model_params=[
|
||||
dict(
|
||||
save_file='end2end.engine',
|
||||
opt_shape_dict=dict(
|
||||
input=[[1, 3, 224, 224], [4, 3, 224, 224], [32, 3, 224, 224]]),
|
||||
input=[[1, 3, 224, 224], [4, 3, 224, 224], [64, 3, 224, 224]]),
|
||||
max_workspace_size=1 << 30)
|
||||
])
|
||||
|
@ -1,5 +1,13 @@
|
||||
from .extract_model import extract_model
|
||||
from .inference import inference_model
|
||||
from .pytorch2onnx import torch2onnx, torch2onnx_impl
|
||||
from .test import post_process_outputs, prepare_data_loader, single_gpu_test
|
||||
from .utils import (assert_cfg_valid, assert_module_exist,
|
||||
get_classes_from_config, init_backend_model)
|
||||
|
||||
__all__ = ['torch2onnx_impl', 'torch2onnx', 'extract_model', 'inference_model']
|
||||
__all__ = [
|
||||
'torch2onnx_impl', 'torch2onnx', 'extract_model', 'inference_model',
|
||||
'prepare_data_loader', 'assert_module_exist', 'assert_cfg_valid',
|
||||
'init_backend_model', 'get_classes_from_config', 'single_gpu_test',
|
||||
'post_process_outputs'
|
||||
]
|
||||
|
142
mmdeploy/apis/test.py
Normal file
142
mmdeploy/apis/test.py
Normal file
@ -0,0 +1,142 @@
|
||||
import warnings
|
||||
from typing import Any, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from mmdeploy.apis.utils import assert_module_exist
|
||||
|
||||
|
||||
def prepare_data_loader(codebase: str, model_cfg: Union[str, mmcv.Config]):
|
||||
# load model_cfg if necessary
|
||||
if isinstance(model_cfg, str):
|
||||
model_cfg = mmcv.Config.fromfile(model_cfg)
|
||||
|
||||
if codebase == 'mmcls':
|
||||
from mmcls.datasets import (build_dataloader, build_dataset)
|
||||
assert_module_exist(codebase)
|
||||
# build dataset and dataloader
|
||||
dataset = build_dataset(model_cfg.data.test)
|
||||
data_loader = build_dataloader(
|
||||
dataset,
|
||||
samples_per_gpu=model_cfg.data.samples_per_gpu,
|
||||
workers_per_gpu=model_cfg.data.workers_per_gpu,
|
||||
shuffle=False,
|
||||
round_up=False)
|
||||
|
||||
elif codebase == 'mmdet':
|
||||
assert_module_exist(codebase)
|
||||
from mmdet.datasets import (build_dataloader, build_dataset,
|
||||
replace_ImageToTensor)
|
||||
# in case the test dataset is concatenated
|
||||
samples_per_gpu = 1
|
||||
if isinstance(model_cfg.data.test, dict):
|
||||
model_cfg.data.test.test_mode = True
|
||||
samples_per_gpu = model_cfg.data.test.pop('samples_per_gpu', 1)
|
||||
if samples_per_gpu > 1:
|
||||
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
|
||||
model_cfg.data.test.pipeline = replace_ImageToTensor(
|
||||
model_cfg.data.test.pipeline)
|
||||
elif isinstance(model_cfg.data.test, list):
|
||||
for ds_cfg in model_cfg.data.test:
|
||||
ds_cfg.test_mode = True
|
||||
samples_per_gpu = max([
|
||||
ds_cfg.pop('samples_per_gpu', 1)
|
||||
for ds_cfg in model_cfg.data.test
|
||||
])
|
||||
if samples_per_gpu > 1:
|
||||
for ds_cfg in model_cfg.data.test:
|
||||
ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
|
||||
|
||||
# build the dataloader
|
||||
dataset = build_dataset(model_cfg.data.test)
|
||||
data_loader = build_dataloader(
|
||||
dataset,
|
||||
samples_per_gpu=samples_per_gpu,
|
||||
workers_per_gpu=model_cfg.data.workers_per_gpu,
|
||||
dist=False,
|
||||
shuffle=False)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
||||
|
||||
return dataset, data_loader
|
||||
|
||||
|
||||
def single_gpu_test(codebase: str,
|
||||
model: nn.Module,
|
||||
data_loader: DataLoader,
|
||||
show: bool = False,
|
||||
out_dir: Any = None,
|
||||
show_score_thr: float = 0.3):
|
||||
if codebase == 'mmcls':
|
||||
assert_module_exist(codebase)
|
||||
from mmcls.apis import single_gpu_test
|
||||
outputs = single_gpu_test(model, data_loader, show, out_dir)
|
||||
elif codebase == 'mmdet':
|
||||
assert_module_exist(codebase)
|
||||
from mmdet.apis import single_gpu_test
|
||||
outputs = single_gpu_test(model, data_loader, show, out_dir,
|
||||
show_score_thr)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
||||
return outputs
|
||||
|
||||
|
||||
def post_process_outputs(outputs,
|
||||
dataset,
|
||||
model_cfg: mmcv.Config,
|
||||
codebase: str,
|
||||
metrics: str = None,
|
||||
out: str = None,
|
||||
metric_options: dict = None,
|
||||
format_only: bool = False):
|
||||
if codebase == 'mmcls':
|
||||
if metrics:
|
||||
results = dataset.evaluate(outputs, metrics, metric_options)
|
||||
for k, v in results.items():
|
||||
print(f'\n{k} : {v:.2f}')
|
||||
else:
|
||||
warnings.warn('Evaluation metrics are not specified.')
|
||||
scores = np.vstack(outputs)
|
||||
pred_score = np.max(scores, axis=1)
|
||||
pred_label = np.argmax(scores, axis=1)
|
||||
pred_class = [dataset.CLASSES[lb] for lb in pred_label]
|
||||
results = {
|
||||
'pred_score': pred_score,
|
||||
'pred_label': pred_label,
|
||||
'pred_class': pred_class
|
||||
}
|
||||
if not out:
|
||||
print('\nthe predicted result for the first element is '
|
||||
f'pred_score = {pred_score[0]:.2f}, '
|
||||
f'pred_label = {pred_label[0]} '
|
||||
f'and pred_class = {pred_class[0]}. '
|
||||
'Specify --out to save all results to files.')
|
||||
if out:
|
||||
print(f'\nwriting results to {out}')
|
||||
mmcv.dump(results, out)
|
||||
|
||||
elif codebase == 'mmdet':
|
||||
if out:
|
||||
print(f'\nwriting results to {out}')
|
||||
mmcv.dump(outputs, out)
|
||||
kwargs = {} if metric_options is None else metric_options
|
||||
if format_only:
|
||||
dataset.format_results(outputs, **kwargs)
|
||||
if metrics:
|
||||
eval_kwargs = model_cfg.get('evaluation', {}).copy()
|
||||
# hard-code way to remove EvalHook args
|
||||
for key in [
|
||||
'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
|
||||
'rule'
|
||||
]:
|
||||
eval_kwargs.pop(key, None)
|
||||
eval_kwargs.update(dict(metric=metrics, **kwargs))
|
||||
print(dataset.evaluate(outputs, **eval_kwargs))
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
@ -6,8 +6,24 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def module_exist(module_name: str):
|
||||
return importlib.util.find_spec(module_name) is not None
|
||||
def assert_cfg_valid(cfg: Union[str, mmcv.Config, mmcv.ConfigDict], *args):
|
||||
"""Check config validation."""
|
||||
|
||||
def _assert_cfg_valid_(cfg):
|
||||
if isinstance(cfg, str):
|
||||
cfg = mmcv.Config.fromfile(cfg)
|
||||
if not isinstance(cfg, (mmcv.Config, mmcv.ConfigDict)):
|
||||
raise TypeError('deploy_cfg must be a filename or Config object, '
|
||||
f'but got {type(cfg)}')
|
||||
|
||||
_assert_cfg_valid_(cfg)
|
||||
for cfg in args:
|
||||
_assert_cfg_valid_(cfg)
|
||||
|
||||
|
||||
def assert_module_exist(module_name: str):
|
||||
if importlib.util.find_spec(module_name) is None:
|
||||
raise ImportError(f'Can not import module: {module_name}')
|
||||
|
||||
|
||||
def init_model(codebase: str,
|
||||
@ -17,25 +33,20 @@ def init_model(codebase: str,
|
||||
cfg_options: Optional[Dict] = None):
|
||||
# mmcls
|
||||
if codebase == 'mmcls':
|
||||
if module_exist(codebase):
|
||||
from mmcls.apis import init_model
|
||||
model = init_model(model_cfg, model_checkpoint, device,
|
||||
cfg_options)
|
||||
else:
|
||||
raise ImportError(f'Can not import module: {codebase}')
|
||||
assert_module_exist(codebase)
|
||||
from mmcls.apis import init_model
|
||||
model = init_model(model_cfg, model_checkpoint, device, cfg_options)
|
||||
|
||||
elif codebase == 'mmdet':
|
||||
if module_exist(codebase):
|
||||
from mmdet.apis import init_detector
|
||||
model = init_detector(model_cfg, model_checkpoint, device,
|
||||
cfg_options)
|
||||
else:
|
||||
raise ImportError(f'Can not import module: {codebase}')
|
||||
assert_module_exist(codebase)
|
||||
from mmdet.apis import init_detector
|
||||
model = init_detector(model_cfg, model_checkpoint, device, cfg_options)
|
||||
|
||||
elif codebase == 'mmseg':
|
||||
if module_exist(codebase):
|
||||
from mmseg.apis import init_segmentor
|
||||
model = init_segmentor(model_cfg, model_checkpoint, device)
|
||||
else:
|
||||
raise ImportError(f'Can not import module: {codebase}')
|
||||
assert_module_exist(codebase)
|
||||
from mmseg.apis import init_segmentor
|
||||
model = init_segmentor(model_cfg, model_checkpoint, device)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
||||
|
||||
@ -54,17 +65,15 @@ def create_input(codebase: str,
|
||||
|
||||
cfg = model_cfg.copy()
|
||||
if codebase == 'mmcls':
|
||||
if module_exist(codebase):
|
||||
from mmdeploy.mmcls.export import create_input
|
||||
return create_input(cfg, imgs, device)
|
||||
else:
|
||||
raise ImportError(f'Can not import module: {codebase}')
|
||||
assert_module_exist(codebase)
|
||||
from mmdeploy.mmcls.export import create_input
|
||||
return create_input(cfg, imgs, device)
|
||||
|
||||
elif codebase == 'mmdet':
|
||||
if module_exist(codebase):
|
||||
from mmdeploy.mmdet.export import create_input
|
||||
return create_input(cfg, imgs, device)
|
||||
else:
|
||||
raise ImportError(f'Can not import module: {codebase}')
|
||||
assert_module_exist(codebase)
|
||||
from mmdeploy.mmdet.export import create_input
|
||||
return create_input(cfg, imgs, device)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
||||
|
||||
@ -86,45 +95,33 @@ def init_backend_model(model_files: Sequence[str],
|
||||
class_names: Sequence[str],
|
||||
device_id: int = 0):
|
||||
if codebase == 'mmcls':
|
||||
if module_exist(codebase):
|
||||
if backend == 'onnxruntime':
|
||||
from mmdeploy.mmcls.export import ONNXRuntimeClassifier
|
||||
backend_model = ONNXRuntimeClassifier(
|
||||
model_files[0],
|
||||
class_names=class_names,
|
||||
device_id=device_id)
|
||||
elif backend == 'tensorrt':
|
||||
from mmdeploy.mmcls.export import TensorRTClassifier
|
||||
backend_model = TensorRTClassifier(
|
||||
model_files[0],
|
||||
class_names=class_names,
|
||||
device_id=device_id)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'Unsupported backend type: {backend}')
|
||||
return backend_model
|
||||
assert_module_exist(codebase)
|
||||
if backend == 'onnxruntime':
|
||||
from mmdeploy.mmcls.export import ONNXRuntimeClassifier
|
||||
backend_model = ONNXRuntimeClassifier(
|
||||
model_files[0], class_names=class_names, device_id=device_id)
|
||||
elif backend == 'tensorrt':
|
||||
from mmdeploy.mmcls.export import TensorRTClassifier
|
||||
backend_model = TensorRTClassifier(
|
||||
model_files[0], class_names=class_names, device_id=device_id)
|
||||
else:
|
||||
raise ImportError(f'Can not import module: {codebase}')
|
||||
raise NotImplementedError(f'Unsupported backend type: {backend}')
|
||||
return backend_model
|
||||
|
||||
elif codebase == 'mmdet':
|
||||
if module_exist(codebase):
|
||||
if backend == 'onnxruntime':
|
||||
from mmdeploy.mmdet.export import ONNXRuntimeDetector
|
||||
backend_model = ONNXRuntimeDetector(
|
||||
model_files[0],
|
||||
class_names=class_names,
|
||||
device_id=device_id)
|
||||
elif backend == 'tensorrt':
|
||||
from mmdeploy.mmdet.export import TensorRTDetector
|
||||
backend_model = TensorRTDetector(
|
||||
model_files[0],
|
||||
class_names=class_names,
|
||||
device_id=device_id)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'Unsupported backend type: {backend}')
|
||||
return backend_model
|
||||
assert_module_exist(codebase)
|
||||
if backend == 'onnxruntime':
|
||||
from mmdeploy.mmdet.export import ONNXRuntimeDetector
|
||||
backend_model = ONNXRuntimeDetector(
|
||||
model_files[0], class_names=class_names, device_id=device_id)
|
||||
elif backend == 'tensorrt':
|
||||
from mmdeploy.mmdet.export import TensorRTDetector
|
||||
backend_model = TensorRTDetector(
|
||||
model_files[0], class_names=class_names, device_id=device_id)
|
||||
else:
|
||||
raise ImportError(f'Can not import module: {codebase}')
|
||||
raise NotImplementedError(f'Unsupported backend type: {backend}')
|
||||
return backend_model
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
||||
|
||||
@ -132,56 +129,51 @@ def init_backend_model(model_files: Sequence[str],
|
||||
def get_classes_from_config(codebase: str, model_cfg: Union[str, mmcv.Config]):
|
||||
model_cfg_str = model_cfg
|
||||
if codebase == 'mmcls':
|
||||
if module_exist(codebase):
|
||||
if isinstance(model_cfg, str):
|
||||
model_cfg = mmcv.Config.fromfile(model_cfg)
|
||||
elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
|
||||
raise TypeError('config must be a filename or Config object, '
|
||||
f'but got {type(model_cfg)}')
|
||||
assert_module_exist(codebase)
|
||||
if isinstance(model_cfg, str):
|
||||
model_cfg = mmcv.Config.fromfile(model_cfg)
|
||||
elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
|
||||
raise TypeError('config must be a filename or Config object, '
|
||||
f'but got {type(model_cfg)}')
|
||||
|
||||
from mmcls.datasets import DATASETS
|
||||
module_dict = DATASETS.module_dict
|
||||
data_cfg = model_cfg.data
|
||||
from mmcls.datasets import DATASETS
|
||||
module_dict = DATASETS.module_dict
|
||||
data_cfg = model_cfg.data
|
||||
|
||||
if 'train' in data_cfg:
|
||||
module = module_dict[data_cfg.train.type]
|
||||
elif 'val' in data_cfg:
|
||||
module = module_dict[data_cfg.val.type]
|
||||
elif 'test' in data_cfg:
|
||||
module = module_dict[data_cfg.test.type]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f'No dataset config found in: {model_cfg_str}')
|
||||
|
||||
return module.CLASSES
|
||||
if 'train' in data_cfg:
|
||||
module = module_dict[data_cfg.train.type]
|
||||
elif 'val' in data_cfg:
|
||||
module = module_dict[data_cfg.val.type]
|
||||
elif 'test' in data_cfg:
|
||||
module = module_dict[data_cfg.test.type]
|
||||
else:
|
||||
raise ImportError(f'Can not import module: {codebase}')
|
||||
raise RuntimeError(f'No dataset config found in: {model_cfg_str}')
|
||||
|
||||
return module.CLASSES
|
||||
|
||||
if codebase == 'mmdet':
|
||||
if module_exist(codebase):
|
||||
if isinstance(model_cfg, str):
|
||||
model_cfg = mmcv.Config.fromfile(model_cfg)
|
||||
elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
|
||||
raise TypeError('config must be a filename or Config object, '
|
||||
f'but got {type(model_cfg)}')
|
||||
assert_module_exist(codebase)
|
||||
if isinstance(model_cfg, str):
|
||||
model_cfg = mmcv.Config.fromfile(model_cfg)
|
||||
elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
|
||||
raise TypeError('config must be a filename or Config object, '
|
||||
f'but got {type(model_cfg)}')
|
||||
|
||||
from mmdet.datasets import DATASETS
|
||||
module_dict = DATASETS.module_dict
|
||||
data_cfg = model_cfg.data
|
||||
from mmdet.datasets import DATASETS
|
||||
module_dict = DATASETS.module_dict
|
||||
data_cfg = model_cfg.data
|
||||
|
||||
if 'train' in data_cfg:
|
||||
module = module_dict[data_cfg.train.type]
|
||||
elif 'val' in data_cfg:
|
||||
module = module_dict[data_cfg.val.type]
|
||||
elif 'test' in data_cfg:
|
||||
module = module_dict[data_cfg.test.type]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f'No dataset config found in: {model_cfg_str}')
|
||||
|
||||
return module.CLASSES
|
||||
if 'train' in data_cfg:
|
||||
module = module_dict[data_cfg.train.type]
|
||||
elif 'val' in data_cfg:
|
||||
module = module_dict[data_cfg.val.type]
|
||||
elif 'test' in data_cfg:
|
||||
module = module_dict[data_cfg.test.type]
|
||||
else:
|
||||
raise ImportError(f'Can not import module: {codebase}')
|
||||
raise RuntimeError(f'No dataset config found in: {model_cfg_str}')
|
||||
|
||||
return module.CLASSES
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
||||
|
||||
@ -195,41 +187,37 @@ def check_model_outputs(codebase: str,
|
||||
show_result=False):
|
||||
show_img = mmcv.imread(image) if isinstance(image, str) else image
|
||||
if codebase == 'mmcls':
|
||||
if module_exist(codebase):
|
||||
output_file = None if show_result else output_file
|
||||
with torch.no_grad():
|
||||
scores = model(**model_inputs, return_loss=False)[0]
|
||||
pred_score = np.max(scores, axis=0)
|
||||
pred_label = np.argmax(scores, axis=0)
|
||||
result = {
|
||||
'pred_label': pred_label,
|
||||
'pred_score': float(pred_score)
|
||||
}
|
||||
result['pred_class'] = model.CLASSES[result['pred_label']]
|
||||
model.show_result(
|
||||
show_img,
|
||||
result,
|
||||
show=True,
|
||||
win_name=backend,
|
||||
out_file=output_file)
|
||||
else:
|
||||
raise ImportError(f'Can not import module: {codebase}')
|
||||
elif codebase == 'mmdet':
|
||||
if module_exist(codebase):
|
||||
output_file = None if show_result else output_file
|
||||
score_thr = 0.3
|
||||
with torch.no_grad():
|
||||
results = model(
|
||||
**model_inputs, return_loss=False, rescale=True)[0]
|
||||
model.show_result(
|
||||
show_img,
|
||||
results,
|
||||
score_thr=score_thr,
|
||||
show=True,
|
||||
win_name=backend,
|
||||
out_file=output_file)
|
||||
assert_module_exist(codebase)
|
||||
output_file = None if show_result else output_file
|
||||
with torch.no_grad():
|
||||
scores = model(**model_inputs, return_loss=False)[0]
|
||||
pred_score = np.max(scores, axis=0)
|
||||
pred_label = np.argmax(scores, axis=0)
|
||||
result = {
|
||||
'pred_label': pred_label,
|
||||
'pred_score': float(pred_score)
|
||||
}
|
||||
result['pred_class'] = model.CLASSES[result['pred_label']]
|
||||
model.show_result(
|
||||
show_img,
|
||||
result,
|
||||
show=True,
|
||||
win_name=backend,
|
||||
out_file=output_file)
|
||||
|
||||
elif codebase == 'mmdet':
|
||||
assert_module_exist(codebase)
|
||||
output_file = None if show_result else output_file
|
||||
score_thr = 0.3
|
||||
with torch.no_grad():
|
||||
results = model(**model_inputs, return_loss=False, rescale=True)[0]
|
||||
model.show_result(
|
||||
show_img,
|
||||
results,
|
||||
score_thr=score_thr,
|
||||
show=True,
|
||||
win_name=backend,
|
||||
out_file=output_file)
|
||||
|
||||
else:
|
||||
raise ImportError(f'Can not import module: {codebase}')
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
||||
|
@ -1,7 +1,6 @@
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
from mmcls.models import BaseClassifier
|
||||
|
||||
@ -11,6 +10,7 @@ class ONNXRuntimeClassifier(BaseClassifier):
|
||||
|
||||
def __init__(self, onnx_file, class_names, device_id):
|
||||
super(ONNXRuntimeClassifier, self).__init__()
|
||||
import onnxruntime as ort
|
||||
sess = ort.InferenceSession(onnx_file)
|
||||
|
||||
providers = ['CPUExecutionProvider']
|
||||
|
@ -1 +1,2 @@
|
||||
from .classifiers import * # noqa: F401,F403
|
||||
from .heads import * # noqa: F401,F403
|
||||
|
13
mmdeploy/mmcls/models/heads/__init__.py
Normal file
13
mmdeploy/mmcls/models/heads/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
from .cls_head import simple_test_of_cls_head
|
||||
from .linear_head import simple_test_of_linear_head
|
||||
from .multi_label_head import simple_test_of_multi_label_head
|
||||
from .multi_label_linear_head import simple_test_of_multi_label_linear_head
|
||||
from .stacked_head import simple_test_of_stacked_head
|
||||
from .vision_transformer_head import simple_test_of_vision_transformer_head
|
||||
|
||||
__all__ = [
|
||||
'simple_test_of_multi_label_linear_head',
|
||||
'simple_test_of_multi_label_head', 'simple_test_of_cls_head',
|
||||
'simple_test_of_linear_head', 'simple_test_of_stacked_head',
|
||||
'simple_test_of_vision_transformer_head'
|
||||
]
|
13
mmdeploy/mmcls/models/heads/cls_head.py
Normal file
13
mmdeploy/mmcls/models/heads/cls_head.py
Normal file
@ -0,0 +1,13 @@
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcls.models.heads.ClsHead.simple_test')
|
||||
def simple_test_of_cls_head(ctx, self, cls_score, **kwargs):
|
||||
"""Test without augmentation."""
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
return pred
|
14
mmdeploy/mmcls/models/heads/linear_head.py
Normal file
14
mmdeploy/mmcls/models/heads/linear_head.py
Normal file
@ -0,0 +1,14 @@
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcls.models.heads.LinearClsHead.simple_test')
|
||||
def simple_test_of_linear_head(ctx, self, img, **kwargs):
|
||||
"""Test without augmentation."""
|
||||
cls_score = self.fc(img)
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
return pred
|
12
mmdeploy/mmcls/models/heads/multi_label_head.py
Normal file
12
mmdeploy/mmcls/models/heads/multi_label_head.py
Normal file
@ -0,0 +1,12 @@
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcls.models.heads.MultiLabelClsHead.simple_test')
|
||||
def simple_test_of_multi_label_head(ctx, self, cls_score, **kwargs):
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.sigmoid(cls_score) if cls_score is not None else None
|
||||
return pred
|
14
mmdeploy/mmcls/models/heads/multi_label_linear_head.py
Normal file
14
mmdeploy/mmcls/models/heads/multi_label_linear_head.py
Normal file
@ -0,0 +1,14 @@
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcls.models.heads.MultiLabelLinearClsHead.simple_test')
|
||||
def simple_test_of_multi_label_linear_head(ctx, self, img, **kwargs):
|
||||
"""Test without augmentation."""
|
||||
cls_score = self.fc(img)
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.sigmoid(cls_score) if cls_score is not None else None
|
||||
return pred
|
16
mmdeploy/mmcls/models/heads/stacked_head.py
Normal file
16
mmdeploy/mmcls/models/heads/stacked_head.py
Normal file
@ -0,0 +1,16 @@
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcls.models.heads.StackedLinearClsHead.simple_test')
|
||||
def simple_test_of_stacked_head(ctx, self, img, **kwargs):
|
||||
"""Test without augmentation."""
|
||||
cls_score = img
|
||||
for layer in self.layers:
|
||||
cls_score = layer(cls_score)
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
return pred
|
14
mmdeploy/mmcls/models/heads/vision_transformer_head.py
Normal file
14
mmdeploy/mmcls/models/heads/vision_transformer_head.py
Normal file
@ -0,0 +1,14 @@
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcls.models.heads.VisionTransformerClsHead.simple_test')
|
||||
def simple_test_of_vision_transformer_head(ctx, self, img, **kwargs):
|
||||
"""Test without augmentation."""
|
||||
cls_score = self.layers(img)
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
return pred
|
@ -1,7 +1,9 @@
|
||||
from .model_wrappers import ONNXRuntimeDetector, TensorRTDetector
|
||||
from .onnx_helper import clip_bboxes
|
||||
from .prepare_input import create_input
|
||||
from .tensorrt_helper import pad_with_value
|
||||
|
||||
__all__ = [
|
||||
'clip_bboxes', 'TensorRTDetector', 'create_input', 'ONNXRuntimeDetector'
|
||||
'clip_bboxes', 'TensorRTDetector', 'create_input', 'ONNXRuntimeDetector',
|
||||
'pad_with_value'
|
||||
]
|
||||
|
@ -74,7 +74,7 @@ class DeployBaseDetector(BaseDetector):
|
||||
img_h, img_w = img_metas[i]['img_shape'][:2]
|
||||
ori_h, ori_w = img_metas[i]['ori_shape'][:2]
|
||||
masks = masks[:, :img_h, :img_w]
|
||||
if rescale:
|
||||
if rescale and batch_masks.shape[1] > 0:
|
||||
masks = masks.astype(np.float32)
|
||||
masks = torch.from_numpy(masks)
|
||||
masks = torch.nn.functional.interpolate(
|
||||
|
18
mmdeploy/mmdet/export/tensorrt_helper.py
Normal file
18
mmdeploy/mmdet/export/tensorrt_helper.py
Normal file
@ -0,0 +1,18 @@
|
||||
import torch
|
||||
|
||||
|
||||
def pad_with_value(x, pad_dim, pad_size, pad_value=None):
|
||||
num_dims = len(x.shape)
|
||||
pad_slice = (slice(None, None, None), ) * num_dims
|
||||
pad_slice = pad_slice[:pad_dim] + (slice(0, 1,
|
||||
1), ) + pad_slice[pad_dim + 1:]
|
||||
repeat_size = [1] * num_dims
|
||||
repeat_size[pad_dim] = pad_size
|
||||
|
||||
x_pad = x.__getitem__(pad_slice)
|
||||
if pad_value is not None:
|
||||
x_pad = x_pad * 0 + pad_value
|
||||
|
||||
x_pad = x_pad.repeat(*repeat_size)
|
||||
x = torch.cat([x, x_pad], dim=pad_dim)
|
||||
return x
|
@ -2,6 +2,7 @@ import torch
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
from mmdeploy.mmdet.core import multiclass_nms
|
||||
from mmdeploy.mmdet.export import pad_with_value
|
||||
from mmdeploy.utils import is_dynamic_shape
|
||||
|
||||
|
||||
@ -55,15 +56,15 @@ def get_bboxes_of_anchor_head(ctx,
|
||||
|
||||
anchors = anchors.expand_as(bbox_pred)
|
||||
|
||||
enable_nms_pre = True
|
||||
backend = deploy_cfg['backend']
|
||||
# topk in tensorrt does not support shape<k
|
||||
# final level might meet the problem
|
||||
# TODO: support dynamic shape feature with TensorRT for topK op
|
||||
# concate zero to enable topk,
|
||||
if backend == 'tensorrt':
|
||||
enable_nms_pre = (level_id != num_levels - 1)
|
||||
anchors = pad_with_value(anchors, 1, pre_topk)
|
||||
bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
|
||||
scores = pad_with_value(scores, 1, pre_topk, 0.)
|
||||
|
||||
if pre_topk > 0 and enable_nms_pre:
|
||||
if pre_topk > 0:
|
||||
# Get maximum scores for foreground classes.
|
||||
if self.use_sigmoid_cls:
|
||||
max_scores, _ = scores.max(-1)
|
||||
|
@ -2,6 +2,7 @@ import torch
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
from mmdeploy.mmdet.core import distance2bbox, multiclass_nms
|
||||
from mmdeploy.mmdet.export import pad_with_value
|
||||
from mmdeploy.utils import is_dynamic_shape
|
||||
|
||||
|
||||
@ -59,14 +60,16 @@ def get_bboxes_of_fcos_head(ctx,
|
||||
|
||||
points = points.expand(batch_size, -1, 2)
|
||||
|
||||
enable_nms_pre = True
|
||||
backend = deploy_cfg['backend']
|
||||
# topk in tensorrt does not support shape<k
|
||||
# final level might meet the problem
|
||||
# concate zero to enable topk,
|
||||
if backend == 'tensorrt':
|
||||
enable_nms_pre = (level_id != num_levels - 1)
|
||||
scores = pad_with_value(scores, 1, pre_topk, 0.)
|
||||
centerness = pad_with_value(centerness, 1, pre_topk)
|
||||
bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
|
||||
points = pad_with_value(points, 1, pre_topk)
|
||||
|
||||
if pre_topk > 0 and enable_nms_pre:
|
||||
if pre_topk > 0:
|
||||
max_scores, _ = (scores * centerness).max(-1)
|
||||
_, topk_inds = max_scores.topk(pre_topk)
|
||||
batch_inds = torch.arange(batch_size).view(-1,
|
||||
@ -92,7 +95,7 @@ def get_bboxes_of_fcos_head(ctx,
|
||||
if not with_nms:
|
||||
return batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_centerness
|
||||
|
||||
batch_mlvl_scores = batch_mlvl_scores * (batch_mlvl_centerness)
|
||||
batch_mlvl_scores = batch_mlvl_scores * batch_mlvl_centerness
|
||||
post_params = deploy_cfg.post_processing
|
||||
max_output_boxes_per_class = post_params.max_output_boxes_per_class
|
||||
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
|
||||
|
@ -2,6 +2,7 @@ import torch
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
from mmdeploy.mmdet.core import multiclass_nms
|
||||
from mmdeploy.mmdet.export import pad_with_value
|
||||
from mmdeploy.utils import is_dynamic_shape
|
||||
|
||||
|
||||
@ -50,6 +51,7 @@ def get_bboxes_of_rpn_head(ctx,
|
||||
# be consistent with other head since mmdet v2.0. In mmdet v2.0
|
||||
# to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
|
||||
scores = cls_score.softmax(-1)[..., 0]
|
||||
scores = scores.reshape(batch_size, -1, 1)
|
||||
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
|
||||
|
||||
# use static anchor if input shape is static
|
||||
@ -58,32 +60,27 @@ def get_bboxes_of_rpn_head(ctx,
|
||||
|
||||
anchors = anchors.expand_as(bbox_pred)
|
||||
|
||||
enable_nms_pre = True
|
||||
backend = deploy_cfg['backend']
|
||||
# topk in tensorrt does not support shape<k
|
||||
# final level might meet the problem
|
||||
# TODO: support dynamic shape feature with TensorRT for topK op
|
||||
# concate zero to enable topk,
|
||||
if backend == 'tensorrt':
|
||||
enable_nms_pre = (level_id != num_levels - 1)
|
||||
scores = pad_with_value(scores, 1, pre_topk, 0.)
|
||||
bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
|
||||
anchors = pad_with_value(anchors, 1, pre_topk)
|
||||
|
||||
if pre_topk > 0 and enable_nms_pre:
|
||||
_, topk_inds = scores.topk(pre_topk)
|
||||
if pre_topk > 0:
|
||||
_, topk_inds = scores.squeeze(2).topk(pre_topk)
|
||||
batch_inds = torch.arange(
|
||||
batch_size, device=device).view(-1, 1).expand_as(topk_inds)
|
||||
# Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
|
||||
transformed_inds = scores.shape[1] * batch_inds + topk_inds
|
||||
scores = scores.reshape(-1, 1)[transformed_inds].reshape(
|
||||
batch_size, -1)
|
||||
bbox_pred = bbox_pred.reshape(-1, 4)[transformed_inds, :].reshape(
|
||||
batch_size, -1, 4)
|
||||
anchors = anchors.reshape(-1, 4)[transformed_inds, :].reshape(
|
||||
batch_size, -1, 4)
|
||||
anchors = anchors[batch_inds, topk_inds, :]
|
||||
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
|
||||
scores = scores[batch_inds, topk_inds, :]
|
||||
mlvl_valid_bboxes.append(bbox_pred)
|
||||
mlvl_scores.append(scores)
|
||||
mlvl_valid_anchors.append(anchors)
|
||||
|
||||
batch_mlvl_bboxes = torch.cat(mlvl_valid_bboxes, dim=1)
|
||||
batch_mlvl_scores = torch.cat(mlvl_scores, dim=1).unsqueeze(2)
|
||||
batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
|
||||
batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1)
|
||||
batch_mlvl_bboxes = self.bbox_coder.decode(
|
||||
batch_mlvl_anchors,
|
||||
|
@ -7,7 +7,8 @@ import mmcv
|
||||
import torch.multiprocessing as mp
|
||||
from torch.multiprocessing import Process, set_start_method
|
||||
|
||||
from mmdeploy.apis import extract_model, inference_model, torch2onnx
|
||||
from mmdeploy.apis import (assert_cfg_valid, extract_model, inference_model,
|
||||
torch2onnx)
|
||||
|
||||
|
||||
def parse_args():
|
||||
@ -70,9 +71,7 @@ def main():
|
||||
|
||||
# load deploy_cfg
|
||||
deploy_cfg = mmcv.Config.fromfile(deploy_cfg_path)
|
||||
if not isinstance(deploy_cfg, (mmcv.Config, mmcv.ConfigDict)):
|
||||
raise TypeError('deploy_cfg must be a filename or Config object, '
|
||||
f'but got {type(deploy_cfg)}')
|
||||
assert_cfg_valid(deploy_cfg, model_cfg_path)
|
||||
|
||||
# create work_dir if not
|
||||
mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
|
||||
|
101
tools/test.py
Normal file
101
tools/test.py
Normal file
@ -0,0 +1,101 @@
|
||||
import argparse
|
||||
|
||||
import mmcv
|
||||
from mmcv import DictAction
|
||||
from mmcv.parallel import MMDataParallel
|
||||
|
||||
from mmdeploy.apis import (init_backend_model, post_process_outputs,
|
||||
prepare_data_loader, single_gpu_test)
|
||||
from mmdeploy.apis.utils import assert_cfg_valid, get_classes_from_config
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='MMDeploy test (and eval) a backend.')
|
||||
parser.add_argument('deploy_cfg', help='Deploy config path')
|
||||
parser.add_argument('model_cfg', help='Model config path')
|
||||
parser.add_argument('model', help='Input model file.')
|
||||
parser.add_argument('--out', help='output result file in pickle format')
|
||||
parser.add_argument(
|
||||
'--format-only',
|
||||
action='store_true',
|
||||
help='Format the output results without perform evaluation. It is'
|
||||
'useful when you want to format the result to a specific format and '
|
||||
'submit it to the test server')
|
||||
parser.add_argument(
|
||||
'--metrics',
|
||||
type=str,
|
||||
nargs='+',
|
||||
help='evaluation metrics, which depends on the codebase and the '
|
||||
'dataset, e.g., "bbox", "segm", "proposal" for COCO, and "mAP", '
|
||||
'"recall" for PASCAL VOC in mmdet; "accuracy", "precision", "recall", '
|
||||
'"f1_score", "support" for single label dataset, and "mAP", "CP", "CR"'
|
||||
', "CF1", "OP", "OR", "OF1" for multi-label dataset in mmcls')
|
||||
parser.add_argument('--show', action='store_true', help='show results')
|
||||
parser.add_argument(
|
||||
'--show-dir', help='directory where painted images will be saved')
|
||||
parser.add_argument(
|
||||
'--show-score-thr',
|
||||
type=float,
|
||||
default=0.3,
|
||||
help='score threshold (default: 0.3)')
|
||||
parser.add_argument(
|
||||
'--device', help='device used for conversion', default='cpu')
|
||||
parser.add_argument(
|
||||
'--cfg-options',
|
||||
nargs='+',
|
||||
action=DictAction,
|
||||
help='override some settings in the used config, the key-value pair '
|
||||
'in xxx=yyy format will be merged into config file. If the value to '
|
||||
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
|
||||
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
|
||||
'Note that the quotation marks are necessary and that no white space '
|
||||
'is allowed.')
|
||||
parser.add_argument(
|
||||
'--metric-options',
|
||||
nargs='+',
|
||||
action=DictAction,
|
||||
help='custom options for evaluation, the key-value pair in xxx=yyy '
|
||||
'format will be kwargs for dataset.evaluate() function')
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
|
||||
raise ValueError('The output file must be a pkl file.')
|
||||
|
||||
deploy_cfg_path = args.deploy_cfg
|
||||
model_cfg_path = args.model_cfg
|
||||
|
||||
# load deploy_cfg
|
||||
deploy_cfg = mmcv.Config.fromfile(deploy_cfg_path)
|
||||
model_cfg = mmcv.Config.fromfile(model_cfg_path)
|
||||
assert_cfg_valid(deploy_cfg, model_cfg)
|
||||
|
||||
# prepare the dataset loader
|
||||
codebase = deploy_cfg['codebase']
|
||||
dataset, data_loader = prepare_data_loader(codebase, model_cfg)
|
||||
|
||||
# load the model of the backend
|
||||
device_id = -1 if args.device == 'cpu' else 0
|
||||
backend = deploy_cfg.get('backend', 'default')
|
||||
model = init_backend_model([args.model],
|
||||
codebase=codebase,
|
||||
backend=backend,
|
||||
class_names=get_classes_from_config(
|
||||
codebase, model_cfg),
|
||||
device_id=device_id)
|
||||
|
||||
model = MMDataParallel(model, device_ids=[0])
|
||||
outputs = single_gpu_test(codebase, model, data_loader, args.show,
|
||||
args.show_dir, args.show_score_thr)
|
||||
|
||||
post_process_outputs(outputs, dataset, model_cfg, codebase, args.metrics,
|
||||
args.out, args.metric_options, args.format_only)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user