[Feature] Align datasets (#29)

* add test tool and re-organize apis.utils

* handle topk and refine codes

* add cls export and test support

* fix lint

* move ort into wrapper

* resolve conflicts

* resolve comments

* resolve conflicts

* resolve comments and add padding for Mask R-CNN

* resolve comments
AllentDan 2021-08-03 17:12:44 +08:00 committed by GitHub
parent 90ce7207da
commit f607f1965b
21 changed files with 532 additions and 176 deletions


@ -3,6 +3,6 @@ tensorrt_params = dict(model_params=[
dict(
save_file='end2end.engine',
opt_shape_dict=dict(
input=[[1, 3, 224, 224], [4, 3, 224, 224], [32, 3, 224, 224]]),
input=[[1, 3, 224, 224], [4, 3, 224, 224], [64, 3, 224, 224]]),
max_workspace_size=1 << 30)
])
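The three shapes under opt_shape_dict are conventionally read as the minimum, optimum and maximum input shapes of a dynamic-batch optimization profile, so this change raises the largest supported batch from 32 to 64. A minimal sketch of how such an entry could map onto the TensorRT Python API (an assumption about the backend wiring, not mmdeploy's actual converter code):

import tensorrt as trt  # noqa: F401


def add_input_profile(builder, config, name, shapes):
    # shapes = [min_shape, opt_shape, max_shape], e.g. the 'input' entry above
    profile = builder.create_optimization_profile()
    profile.set_shape(name, *shapes)
    config.add_optimization_profile(profile)


# add_input_profile(builder, config, 'input',
#                   [[1, 3, 224, 224], [4, 3, 224, 224], [64, 3, 224, 224]])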


@ -1,5 +1,13 @@
from .extract_model import extract_model
from .inference import inference_model
from .pytorch2onnx import torch2onnx, torch2onnx_impl
from .test import post_process_outputs, prepare_data_loader, single_gpu_test
from .utils import (assert_cfg_valid, assert_module_exist,
get_classes_from_config, init_backend_model)
__all__ = ['torch2onnx_impl', 'torch2onnx', 'extract_model', 'inference_model']
__all__ = [
'torch2onnx_impl', 'torch2onnx', 'extract_model', 'inference_model',
'prepare_data_loader', 'assert_module_exist', 'assert_cfg_valid',
'init_backend_model', 'get_classes_from_config', 'single_gpu_test',
'post_process_outputs'
]

mmdeploy/apis/test.py (new file, 142 lines)

@ -0,0 +1,142 @@
import warnings
from typing import Any, Union
import mmcv
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from mmdeploy.apis.utils import assert_module_exist
def prepare_data_loader(codebase: str, model_cfg: Union[str, mmcv.Config]):
# load model_cfg if necessary
if isinstance(model_cfg, str):
model_cfg = mmcv.Config.fromfile(model_cfg)
if codebase == 'mmcls':
from mmcls.datasets import (build_dataloader, build_dataset)
assert_module_exist(codebase)
# build dataset and dataloader
dataset = build_dataset(model_cfg.data.test)
data_loader = build_dataloader(
dataset,
samples_per_gpu=model_cfg.data.samples_per_gpu,
workers_per_gpu=model_cfg.data.workers_per_gpu,
shuffle=False,
round_up=False)
elif codebase == 'mmdet':
assert_module_exist(codebase)
from mmdet.datasets import (build_dataloader, build_dataset,
replace_ImageToTensor)
# in case the test dataset is concatenated
samples_per_gpu = 1
if isinstance(model_cfg.data.test, dict):
model_cfg.data.test.test_mode = True
samples_per_gpu = model_cfg.data.test.pop('samples_per_gpu', 1)
if samples_per_gpu > 1:
# Replace 'ImageToTensor' with 'DefaultFormatBundle'
model_cfg.data.test.pipeline = replace_ImageToTensor(
model_cfg.data.test.pipeline)
elif isinstance(model_cfg.data.test, list):
for ds_cfg in model_cfg.data.test:
ds_cfg.test_mode = True
samples_per_gpu = max([
ds_cfg.pop('samples_per_gpu', 1)
for ds_cfg in model_cfg.data.test
])
if samples_per_gpu > 1:
for ds_cfg in model_cfg.data.test:
ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
# build the dataloader
dataset = build_dataset(model_cfg.data.test)
data_loader = build_dataloader(
dataset,
samples_per_gpu=samples_per_gpu,
workers_per_gpu=model_cfg.data.workers_per_gpu,
dist=False,
shuffle=False)
else:
raise NotImplementedError(f'Unknown codebase type: {codebase}')
return dataset, data_loader
def single_gpu_test(codebase: str,
model: nn.Module,
data_loader: DataLoader,
show: bool = False,
out_dir: Any = None,
show_score_thr: float = 0.3):
if codebase == 'mmcls':
assert_module_exist(codebase)
from mmcls.apis import single_gpu_test
outputs = single_gpu_test(model, data_loader, show, out_dir)
elif codebase == 'mmdet':
assert_module_exist(codebase)
from mmdet.apis import single_gpu_test
outputs = single_gpu_test(model, data_loader, show, out_dir,
show_score_thr)
else:
raise NotImplementedError(f'Unknown codebase type: {codebase}')
return outputs
def post_process_outputs(outputs,
dataset,
model_cfg: mmcv.Config,
codebase: str,
metrics: str = None,
out: str = None,
metric_options: dict = None,
format_only: bool = False):
if codebase == 'mmcls':
if metrics:
results = dataset.evaluate(outputs, metrics, metric_options)
for k, v in results.items():
print(f'\n{k} : {v:.2f}')
else:
warnings.warn('Evaluation metrics are not specified.')
scores = np.vstack(outputs)
pred_score = np.max(scores, axis=1)
pred_label = np.argmax(scores, axis=1)
pred_class = [dataset.CLASSES[lb] for lb in pred_label]
results = {
'pred_score': pred_score,
'pred_label': pred_label,
'pred_class': pred_class
}
if not out:
print('\nthe predicted result for the first element is '
f'pred_score = {pred_score[0]:.2f}, '
f'pred_label = {pred_label[0]} '
f'and pred_class = {pred_class[0]}. '
'Specify --out to save all results to files.')
if out:
print(f'\nwriting results to {out}')
mmcv.dump(results, out)
elif codebase == 'mmdet':
if out:
print(f'\nwriting results to {out}')
mmcv.dump(outputs, out)
kwargs = {} if metric_options is None else metric_options
if format_only:
dataset.format_results(outputs, **kwargs)
if metrics:
eval_kwargs = model_cfg.get('evaluation', {}).copy()
# hard-code way to remove EvalHook args
for key in [
'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
'rule'
]:
eval_kwargs.pop(key, None)
eval_kwargs.update(dict(metric=metrics, **kwargs))
print(dataset.evaluate(outputs, **eval_kwargs))
else:
raise NotImplementedError(f'Unknown codebase type: {codebase}')
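Taken together, prepare_data_loader, single_gpu_test and post_process_outputs form the evaluation path that tools/test.py (added further down) drives. A minimal usage sketch; the config path and ONNX file are placeholders, and a GPU at index 0 is assumed:

import mmcv
from mmcv.parallel import MMDataParallel

from mmdeploy.apis import (init_backend_model, post_process_outputs,
                           prepare_data_loader, single_gpu_test)
from mmdeploy.apis.utils import get_classes_from_config

codebase = 'mmcls'
model_cfg = mmcv.Config.fromfile('configs/cls_config.py')  # placeholder path
dataset, data_loader = prepare_data_loader(codebase, model_cfg)
model = init_backend_model(
    ['end2end.onnx'],  # placeholder model file
    codebase=codebase,
    backend='onnxruntime',
    class_names=get_classes_from_config(codebase, model_cfg),
    device_id=0)
model = MMDataParallel(model, device_ids=[0])
outputs = single_gpu_test(codebase, model, data_loader)
post_process_outputs(outputs, dataset, model_cfg, codebase, metrics='accuracy')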


@ -6,8 +6,24 @@ import numpy as np
import torch
def module_exist(module_name: str):
return importlib.util.find_spec(module_name) is not None
def assert_cfg_valid(cfg: Union[str, mmcv.Config, mmcv.ConfigDict], *args):
"""Check config validation."""
def _assert_cfg_valid_(cfg):
if isinstance(cfg, str):
cfg = mmcv.Config.fromfile(cfg)
if not isinstance(cfg, (mmcv.Config, mmcv.ConfigDict)):
raise TypeError('deploy_cfg must be a filename or Config object, '
f'but got {type(cfg)}')
_assert_cfg_valid_(cfg)
for cfg in args:
_assert_cfg_valid_(cfg)
def assert_module_exist(module_name: str):
if importlib.util.find_spec(module_name) is None:
raise ImportError(f'Can not import module: {module_name}')
def init_model(codebase: str,
@ -17,25 +33,20 @@ def init_model(codebase: str,
cfg_options: Optional[Dict] = None):
# mmcls
if codebase == 'mmcls':
if module_exist(codebase):
assert_module_exist(codebase)
from mmcls.apis import init_model
model = init_model(model_cfg, model_checkpoint, device,
cfg_options)
else:
raise ImportError(f'Can not import module: {codebase}')
model = init_model(model_cfg, model_checkpoint, device, cfg_options)
elif codebase == 'mmdet':
if module_exist(codebase):
assert_module_exist(codebase)
from mmdet.apis import init_detector
model = init_detector(model_cfg, model_checkpoint, device,
cfg_options)
else:
raise ImportError(f'Can not import module: {codebase}')
model = init_detector(model_cfg, model_checkpoint, device, cfg_options)
elif codebase == 'mmseg':
if module_exist(codebase):
assert_module_exist(codebase)
from mmseg.apis import init_segmentor
model = init_segmentor(model_cfg, model_checkpoint, device)
else:
raise ImportError(f'Can not import module: {codebase}')
else:
raise NotImplementedError(f'Unknown codebase type: {codebase}')
@ -54,17 +65,15 @@ def create_input(codebase: str,
cfg = model_cfg.copy()
if codebase == 'mmcls':
if module_exist(codebase):
assert_module_exist(codebase)
from mmdeploy.mmcls.export import create_input
return create_input(cfg, imgs, device)
else:
raise ImportError(f'Can not import module: {codebase}')
elif codebase == 'mmdet':
if module_exist(codebase):
assert_module_exist(codebase)
from mmdeploy.mmdet.export import create_input
return create_input(cfg, imgs, device)
else:
raise ImportError(f'Can not import module: {codebase}')
else:
raise NotImplementedError(f'Unknown codebase type: {codebase}')
@ -86,45 +95,33 @@ def init_backend_model(model_files: Sequence[str],
class_names: Sequence[str],
device_id: int = 0):
if codebase == 'mmcls':
if module_exist(codebase):
assert_module_exist(codebase)
if backend == 'onnxruntime':
from mmdeploy.mmcls.export import ONNXRuntimeClassifier
backend_model = ONNXRuntimeClassifier(
model_files[0],
class_names=class_names,
device_id=device_id)
model_files[0], class_names=class_names, device_id=device_id)
elif backend == 'tensorrt':
from mmdeploy.mmcls.export import TensorRTClassifier
backend_model = TensorRTClassifier(
model_files[0],
class_names=class_names,
device_id=device_id)
model_files[0], class_names=class_names, device_id=device_id)
else:
raise NotImplementedError(
f'Unsupported backend type: {backend}')
raise NotImplementedError(f'Unsupported backend type: {backend}')
return backend_model
else:
raise ImportError(f'Can not import module: {codebase}')
elif codebase == 'mmdet':
if module_exist(codebase):
assert_module_exist(codebase)
if backend == 'onnxruntime':
from mmdeploy.mmdet.export import ONNXRuntimeDetector
backend_model = ONNXRuntimeDetector(
model_files[0],
class_names=class_names,
device_id=device_id)
model_files[0], class_names=class_names, device_id=device_id)
elif backend == 'tensorrt':
from mmdeploy.mmdet.export import TensorRTDetector
backend_model = TensorRTDetector(
model_files[0],
class_names=class_names,
device_id=device_id)
model_files[0], class_names=class_names, device_id=device_id)
else:
raise NotImplementedError(
f'Unsupported backend type: {backend}')
raise NotImplementedError(f'Unsupported backend type: {backend}')
return backend_model
else:
raise ImportError(f'Can not import module: {codebase}')
else:
raise NotImplementedError(f'Unknown codebase type: {codebase}')
@ -132,7 +129,7 @@ def init_backend_model(model_files: Sequence[str],
def get_classes_from_config(codebase: str, model_cfg: Union[str, mmcv.Config]):
model_cfg_str = model_cfg
if codebase == 'mmcls':
if module_exist(codebase):
assert_module_exist(codebase)
if isinstance(model_cfg, str):
model_cfg = mmcv.Config.fromfile(model_cfg)
elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
@ -150,15 +147,12 @@ def get_classes_from_config(codebase: str, model_cfg: Union[str, mmcv.Config]):
elif 'test' in data_cfg:
module = module_dict[data_cfg.test.type]
else:
raise RuntimeError(
f'No dataset config found in: {model_cfg_str}')
raise RuntimeError(f'No dataset config found in: {model_cfg_str}')
return module.CLASSES
else:
raise ImportError(f'Can not import module: {codebase}')
if codebase == 'mmdet':
if module_exist(codebase):
assert_module_exist(codebase)
if isinstance(model_cfg, str):
model_cfg = mmcv.Config.fromfile(model_cfg)
elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
@ -176,12 +170,10 @@ def get_classes_from_config(codebase: str, model_cfg: Union[str, mmcv.Config]):
elif 'test' in data_cfg:
module = module_dict[data_cfg.test.type]
else:
raise RuntimeError(
f'No dataset config found in: {model_cfg_str}')
raise RuntimeError(f'No dataset config found in: {model_cfg_str}')
return module.CLASSES
else:
raise ImportError(f'Can not import module: {codebase}')
else:
raise NotImplementedError(f'Unknown codebase type: {codebase}')
@ -195,7 +187,7 @@ def check_model_outputs(codebase: str,
show_result=False):
show_img = mmcv.imread(image) if isinstance(image, str) else image
if codebase == 'mmcls':
if module_exist(codebase):
assert_module_exist(codebase)
output_file = None if show_result else output_file
with torch.no_grad():
scores = model(**model_inputs, return_loss=False)[0]
@ -212,15 +204,13 @@ def check_model_outputs(codebase: str,
show=True,
win_name=backend,
out_file=output_file)
else:
raise ImportError(f'Can not import module: {codebase}')
elif codebase == 'mmdet':
if module_exist(codebase):
assert_module_exist(codebase)
output_file = None if show_result else output_file
score_thr = 0.3
with torch.no_grad():
results = model(
**model_inputs, return_loss=False, rescale=True)[0]
results = model(**model_inputs, return_loss=False, rescale=True)[0]
model.show_result(
show_img,
results,
@ -229,7 +219,5 @@ def check_model_outputs(codebase: str,
win_name=backend,
out_file=output_file)
else:
raise ImportError(f'Can not import module: {codebase}')
else:
raise NotImplementedError(f'Unknown codebase type: {codebase}')
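The repeated "module_exist(...) ... else raise ImportError" branches above are collapsed into assert_module_exist, and the ad-hoc config type checks into assert_cfg_valid. The intended call pattern, as a small sketch with a hypothetical helper:

from mmdeploy.apis.utils import assert_cfg_valid, assert_module_exist


def build_backend_something(codebase, deploy_cfg, model_cfg):  # hypothetical helper
    assert_cfg_valid(deploy_cfg, model_cfg)  # TypeError if not a filename or Config
    assert_module_exist(codebase)            # ImportError if the codebase is missing
    ...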


@ -1,7 +1,6 @@
import warnings
import numpy as np
import onnxruntime as ort
import torch
from mmcls.models import BaseClassifier
@ -11,6 +10,7 @@ class ONNXRuntimeClassifier(BaseClassifier):
def __init__(self, onnx_file, class_names, device_id):
super(ONNXRuntimeClassifier, self).__init__()
import onnxruntime as ort
sess = ort.InferenceSession(onnx_file)
providers = ['CPUExecutionProvider']
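This is the "move ort into wrapper" item from the commit list: onnxruntime is now imported inside __init__, so the export module can be imported on machines without onnxruntime installed. The pattern, as a tiny standalone sketch (the class name is hypothetical):

class LazyOrtWrapper:  # hypothetical name, illustration only
    def __init__(self, onnx_file: str):
        import onnxruntime as ort  # deferred until an instance is created
        self.sess = ort.InferenceSession(onnx_file)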


@ -1 +1,2 @@
from .classifiers import * # noqa: F401,F403
from .heads import * # noqa: F401,F403


@ -0,0 +1,13 @@
from .cls_head import simple_test_of_cls_head
from .linear_head import simple_test_of_linear_head
from .multi_label_head import simple_test_of_multi_label_head
from .multi_label_linear_head import simple_test_of_multi_label_linear_head
from .stacked_head import simple_test_of_stacked_head
from .vision_transformer_head import simple_test_of_vision_transformer_head
__all__ = [
'simple_test_of_multi_label_linear_head',
'simple_test_of_multi_label_head', 'simple_test_of_cls_head',
'simple_test_of_linear_head', 'simple_test_of_stacked_head',
'simple_test_of_vision_transformer_head'
]


@ -0,0 +1,13 @@
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.heads.ClsHead.simple_test')
def simple_test_of_cls_head(ctx, self, cls_score, **kwargs):
"""Test without augmentation."""
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
return pred
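Each of the head rewrites in this PR follows the same recipe: register a replacement for simple_test that returns the prediction tensor itself (traceable by torch2onnx) instead of the post-processed numpy output. A sketch of the same pattern applied to a hypothetical custom head (the func_name target does not exist; it only illustrates the recipe):

import torch.nn.functional as F

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
    func_name='my_codebase.models.MyHead.simple_test')  # hypothetical target
def simple_test_of_my_head(ctx, self, feats, **kwargs):
    cls_score = self.fc(feats)
    # keep the tensor so the ONNX export can trace it
    return F.softmax(cls_score, dim=1)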


@ -0,0 +1,14 @@
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.heads.LinearClsHead.simple_test')
def simple_test_of_linear_head(ctx, self, img, **kwargs):
"""Test without augmentation."""
cls_score = self.fc(img)
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
return pred


@ -0,0 +1,12 @@
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.heads.MultiLabelClsHead.simple_test')
def simple_test_of_multi_label_head(ctx, self, cls_score, **kwargs):
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.sigmoid(cls_score) if cls_score is not None else None
return pred


@ -0,0 +1,14 @@
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.heads.MultiLabelLinearClsHead.simple_test')
def simple_test_of_multi_label_linear_head(ctx, self, img, **kwargs):
"""Test without augmentation."""
cls_score = self.fc(img)
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.sigmoid(cls_score) if cls_score is not None else None
return pred


@ -0,0 +1,16 @@
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.heads.StackedLinearClsHead.simple_test')
def simple_test_of_stacked_head(ctx, self, img, **kwargs):
"""Test without augmentation."""
cls_score = img
for layer in self.layers:
cls_score = layer(cls_score)
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
return pred


@ -0,0 +1,14 @@
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.heads.VisionTransformerClsHead.simple_test')
def simple_test_of_vision_transformer_head(ctx, self, img, **kwargs):
"""Test without augmentation."""
cls_score = self.layers(img)
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
return pred


@ -1,7 +1,9 @@
from .model_wrappers import ONNXRuntimeDetector, TensorRTDetector
from .onnx_helper import clip_bboxes
from .prepare_input import create_input
from .tensorrt_helper import pad_with_value
__all__ = [
'clip_bboxes', 'TensorRTDetector', 'create_input', 'ONNXRuntimeDetector'
'clip_bboxes', 'TensorRTDetector', 'create_input', 'ONNXRuntimeDetector',
'pad_with_value'
]


@ -74,7 +74,7 @@ class DeployBaseDetector(BaseDetector):
img_h, img_w = img_metas[i]['img_shape'][:2]
ori_h, ori_w = img_metas[i]['ori_shape'][:2]
masks = masks[:, :img_h, :img_w]
if rescale:
if rescale and batch_masks.shape[1] > 0:
masks = masks.astype(np.float32)
masks = torch.from_numpy(masks)
masks = torch.nn.functional.interpolate(


@ -0,0 +1,18 @@
import torch
def pad_with_value(x, pad_dim, pad_size, pad_value=None):
    """Pad a tensor with `pad_size` extra entries along `pad_dim`.

    The appended entries are copies of the first slice along `pad_dim`,
    or `pad_value` if it is given.
    """
    num_dims = len(x.shape)
    # build a slice that selects only the first entry along pad_dim
    pad_slice = (slice(None, None, None), ) * num_dims
    pad_slice = pad_slice[:pad_dim] + (slice(0, 1,
                                             1), ) + pad_slice[pad_dim + 1:]
    repeat_size = [1] * num_dims
    repeat_size[pad_dim] = pad_size
    x_pad = x.__getitem__(pad_slice)
    if pad_value is not None:
        x_pad = x_pad * 0 + pad_value
    x_pad = x_pad.repeat(*repeat_size)
    x = torch.cat([x, x_pad], dim=pad_dim)
    return x
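A quick shape check for pad_with_value, assuming the definition above:

import torch

x = torch.rand(2, 5, 4)                    # e.g. (batch, num_priors, 4)
padded = pad_with_value(x, pad_dim=1, pad_size=1000, pad_value=0.)
assert padded.shape == (2, 1005, 4)
assert padded[:, 5:, :].abs().sum() == 0   # appended entries hold the pad value
# without pad_value, the first slice along pad_dim is repeated instead
padded2 = pad_with_value(x, pad_dim=1, pad_size=3)
assert torch.equal(padded2[:, 5:, :], x[:, 0:1, :].expand(-1, 3, -1))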


@ -2,6 +2,7 @@ import torch
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.mmdet.core import multiclass_nms
from mmdeploy.mmdet.export import pad_with_value
from mmdeploy.utils import is_dynamic_shape
@ -55,15 +56,15 @@ def get_bboxes_of_anchor_head(ctx,
anchors = anchors.expand_as(bbox_pred)
enable_nms_pre = True
backend = deploy_cfg['backend']
# topk in tensorrt does not support shape<k
# final level might meet the problem
# TODO: support dynamic shape feature with TensorRT for topK op
# concatenate zeros to enable topk
if backend == 'tensorrt':
enable_nms_pre = (level_id != num_levels - 1)
anchors = pad_with_value(anchors, 1, pre_topk)
bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
scores = pad_with_value(scores, 1, pre_topk, 0.)
if pre_topk > 0 and enable_nms_pre:
if pre_topk > 0:
# Get maximum scores for foreground classes.
if self.use_sigmoid_cls:
max_scores, _ = scores.max(-1)
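The padding above exists because TopK with a static k requires the dimension it runs over to be at least k on the TensorRT side; padding every level with pre_topk extra entries guarantees that even for small feature levels, so the enable_nms_pre special case for the last level can be dropped. The same trick appears in the FCOS and RPN heads below. A standalone sketch, assuming pad_with_value from this PR:

import torch

from mmdeploy.mmdet.export import pad_with_value

pre_topk = 1000
scores = torch.rand(1, 400, 80)  # a level with fewer than pre_topk anchors
scores = pad_with_value(scores, 1, pre_topk, 0.)  # dim 1 becomes 400 + 1000
max_scores, _ = scores.max(-1)
_, topk_inds = max_scores.topk(pre_topk)  # k can no longer exceed the length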


@ -2,6 +2,7 @@ import torch
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.mmdet.core import distance2bbox, multiclass_nms
from mmdeploy.mmdet.export import pad_with_value
from mmdeploy.utils import is_dynamic_shape
@ -59,14 +60,16 @@ def get_bboxes_of_fcos_head(ctx,
points = points.expand(batch_size, -1, 2)
enable_nms_pre = True
backend = deploy_cfg['backend']
# topk in tensorrt does not support shape<k
# final level might meet the problem
# concatenate zeros to enable topk
if backend == 'tensorrt':
enable_nms_pre = (level_id != num_levels - 1)
scores = pad_with_value(scores, 1, pre_topk, 0.)
centerness = pad_with_value(centerness, 1, pre_topk)
bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
points = pad_with_value(points, 1, pre_topk)
if pre_topk > 0 and enable_nms_pre:
if pre_topk > 0:
max_scores, _ = (scores * centerness).max(-1)
_, topk_inds = max_scores.topk(pre_topk)
batch_inds = torch.arange(batch_size).view(-1,
@ -92,7 +95,7 @@ def get_bboxes_of_fcos_head(ctx,
if not with_nms:
return batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_centerness
batch_mlvl_scores = batch_mlvl_scores * (batch_mlvl_centerness)
batch_mlvl_scores = batch_mlvl_scores * batch_mlvl_centerness
post_params = deploy_cfg.post_processing
max_output_boxes_per_class = post_params.max_output_boxes_per_class
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)


@ -2,6 +2,7 @@ import torch
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.mmdet.core import multiclass_nms
from mmdeploy.mmdet.export import pad_with_value
from mmdeploy.utils import is_dynamic_shape
@ -50,6 +51,7 @@ def get_bboxes_of_rpn_head(ctx,
# be consistent with other head since mmdet v2.0. In mmdet v2.0
# to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
scores = cls_score.softmax(-1)[..., 0]
scores = scores.reshape(batch_size, -1, 1)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
# use static anchor if input shape is static
@ -58,32 +60,27 @@ def get_bboxes_of_rpn_head(ctx,
anchors = anchors.expand_as(bbox_pred)
enable_nms_pre = True
backend = deploy_cfg['backend']
# topk in tensorrt does not support shape<k
# final level might meet the problem
# TODO: support dynamic shape feature with TensorRT for topK op
# concatenate zeros to enable topk
if backend == 'tensorrt':
enable_nms_pre = (level_id != num_levels - 1)
scores = pad_with_value(scores, 1, pre_topk, 0.)
bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
anchors = pad_with_value(anchors, 1, pre_topk)
if pre_topk > 0 and enable_nms_pre:
_, topk_inds = scores.topk(pre_topk)
if pre_topk > 0:
_, topk_inds = scores.squeeze(2).topk(pre_topk)
batch_inds = torch.arange(
batch_size, device=device).view(-1, 1).expand_as(topk_inds)
# Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
transformed_inds = scores.shape[1] * batch_inds + topk_inds
scores = scores.reshape(-1, 1)[transformed_inds].reshape(
batch_size, -1)
bbox_pred = bbox_pred.reshape(-1, 4)[transformed_inds, :].reshape(
batch_size, -1, 4)
anchors = anchors.reshape(-1, 4)[transformed_inds, :].reshape(
batch_size, -1, 4)
anchors = anchors[batch_inds, topk_inds, :]
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
scores = scores[batch_inds, topk_inds, :]
mlvl_valid_bboxes.append(bbox_pred)
mlvl_scores.append(scores)
mlvl_valid_anchors.append(anchors)
batch_mlvl_bboxes = torch.cat(mlvl_valid_bboxes, dim=1)
batch_mlvl_scores = torch.cat(mlvl_scores, dim=1).unsqueeze(2)
batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1)
batch_mlvl_bboxes = self.bbox_coder.decode(
batch_mlvl_anchors,
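The flattened-index gather (the workaround for the linked onnx2tensorrt issue) is replaced here by plain batched advanced indexing. A standalone check that the two styles produce the same result on toy tensors:

import torch

batch_size, num_anchors, pre_topk = 2, 10, 3
scores = torch.rand(batch_size, num_anchors, 1)
bbox_pred = torch.rand(batch_size, num_anchors, 4)

_, topk_inds = scores.squeeze(2).topk(pre_topk)
batch_inds = torch.arange(batch_size).view(-1, 1).expand_as(topk_inds)

# new style: batched advanced indexing
new_boxes = bbox_pred[batch_inds, topk_inds, :]

# old style: flatten, gather with transformed indices, reshape back
transformed_inds = num_anchors * batch_inds + topk_inds
old_boxes = bbox_pred.reshape(-1, 4)[transformed_inds, :].reshape(
    batch_size, -1, 4)

assert torch.equal(new_boxes, old_boxes)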


@ -7,7 +7,8 @@ import mmcv
import torch.multiprocessing as mp
from torch.multiprocessing import Process, set_start_method
from mmdeploy.apis import extract_model, inference_model, torch2onnx
from mmdeploy.apis import (assert_cfg_valid, extract_model, inference_model,
torch2onnx)
def parse_args():
@ -70,9 +71,7 @@ def main():
# load deploy_cfg
deploy_cfg = mmcv.Config.fromfile(deploy_cfg_path)
if not isinstance(deploy_cfg, (mmcv.Config, mmcv.ConfigDict)):
raise TypeError('deploy_cfg must be a filename or Config object, '
f'but got {type(deploy_cfg)}')
assert_cfg_valid(deploy_cfg, model_cfg_path)
# create work_dir if not
mmcv.mkdir_or_exist(osp.abspath(args.work_dir))

tools/test.py (new file, 101 lines)

@ -0,0 +1,101 @@
import argparse
import mmcv
from mmcv import DictAction
from mmcv.parallel import MMDataParallel
from mmdeploy.apis import (init_backend_model, post_process_outputs,
prepare_data_loader, single_gpu_test)
from mmdeploy.apis.utils import assert_cfg_valid, get_classes_from_config
def parse_args():
parser = argparse.ArgumentParser(
description='MMDeploy test (and eval) a backend.')
parser.add_argument('deploy_cfg', help='Deploy config path')
parser.add_argument('model_cfg', help='Model config path')
parser.add_argument('model', help='Input model file.')
parser.add_argument('--out', help='output result file in pickle format')
parser.add_argument(
'--format-only',
action='store_true',
help='Format the output results without performing evaluation. It is'
'useful when you want to format the result to a specific format and '
'submit it to the test server')
parser.add_argument(
'--metrics',
type=str,
nargs='+',
help='evaluation metrics, which depends on the codebase and the '
'dataset, e.g., "bbox", "segm", "proposal" for COCO, and "mAP", '
'"recall" for PASCAL VOC in mmdet; "accuracy", "precision", "recall", '
'"f1_score", "support" for single label dataset, and "mAP", "CP", "CR"'
', "CF1", "OP", "OR", "OF1" for multi-label dataset in mmcls')
parser.add_argument('--show', action='store_true', help='show results')
parser.add_argument(
'--show-dir', help='directory where painted images will be saved')
parser.add_argument(
'--show-score-thr',
type=float,
default=0.3,
help='score threshold (default: 0.3)')
parser.add_argument(
'--device', help='device used for conversion', default='cpu')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--metric-options',
nargs='+',
action=DictAction,
help='custom options for evaluation, the key-value pair in xxx=yyy '
'format will be kwargs for dataset.evaluate() function')
args = parser.parse_args()
return args
def main():
args = parse_args()
if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
raise ValueError('The output file must be a pkl file.')
deploy_cfg_path = args.deploy_cfg
model_cfg_path = args.model_cfg
# load deploy_cfg
deploy_cfg = mmcv.Config.fromfile(deploy_cfg_path)
model_cfg = mmcv.Config.fromfile(model_cfg_path)
assert_cfg_valid(deploy_cfg, model_cfg)
# prepare the dataset loader
codebase = deploy_cfg['codebase']
dataset, data_loader = prepare_data_loader(codebase, model_cfg)
# load the model of the backend
device_id = -1 if args.device == 'cpu' else 0
backend = deploy_cfg.get('backend', 'default')
model = init_backend_model([args.model],
codebase=codebase,
backend=backend,
class_names=get_classes_from_config(
codebase, model_cfg),
device_id=device_id)
model = MMDataParallel(model, device_ids=[0])
outputs = single_gpu_test(codebase, model, data_loader, args.show,
args.show_dir, args.show_score_thr)
post_process_outputs(outputs, dataset, model_cfg, codebase, args.metrics,
args.out, args.metric_options, args.format_only)
if __name__ == '__main__':
main()