[Feature] Much better RetinaNet support (#6)

* Better RetinaNet support

* Prepare split export for TensorRT

* Optimizer cfg

* Free anchors from the graph when the input shape is static

* Fix docstring

* Use a function rewriter instead of a module rewriter for RetinaNet

* Fix a bug in the mmcls TensorRT config

* Add single-stage mark and static-shape support
q.yao 2021-07-01 17:32:33 +08:00 committed by GitHub
parent 27880afdcd
commit 52fd08febd
8 changed files with 154 additions and 298 deletions

@@ -1,3 +1,44 @@
# MMDeployment
WIP
## Installation
- Build backend ops
- Build with ONNX Runtime support
```bash
mkdir build
cd build
cmake -DBUILD_ONNXRUNTIME_OPS=ON -DONNXRUNTIME_DIR=${PATH_TO_ONNXRUNTIME} ..
make -j10
```
- Build with TensorRT support
```bash
mkdir build
cd build
cmake -DBUILD_TENSORRT_OPS=ON -DTENSORRT_DIR=${PATH_TO_TENSORRT} ..
make -j10
```
- Alternatively, combine the flags above in a single CMake invocation to build ops for multiple backends at once; see the sketch after this list.
- Setup project
```bash
python setup.py develop
```
## Usage
```bash
python ./tools/deploy.py \
${DEPLOY_CFG_PATH} \
${MODEL_CFG_PATH} \
${MODEL_CHECKPOINT_PATH} \
${INPUT_IMG} \
--work-dir ${WORK_DIR} \
--device ${DEVICE} \
--log-level INFO
```
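For example, a hypothetical RetinaNet conversion could look like the following; every path below is illustrative, not a file shipped in this commit:
```bash
# all paths are illustrative
python ./tools/deploy.py \
    configs/mmdet/retinanet_tensorrt.py \
    retinanet_r50_fpn_1x_coco.py \
    retinanet_r50_fpn_1x_coco.pth \
    demo.jpg \
    --work-dir ./work_dir \
    --device cuda:0 \
    --log-level INFO
```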

@@ -1,8 +1,8 @@
 _base_ = ['./mmcls_base.py', '../_base_/backends/tensorrt.py']
 tensorrt_param = dict(model_params=[
     dict(
+        save_file='end2end.engine',
         opt_shape_dict=dict(
-            save_file='end2end.engine',
             input=[[1, 3, 224, 224], [4, 3, 224, 224], [32, 3, 224, 224]]),
         max_workspace_size=1 << 30)
 ])
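The three shapes under `input` are the minimum, optimum, and maximum input sizes, so a single engine serves batch sizes 1 through 32 at 224x224, and `max_workspace_size=1 << 30` grants TensorRT 1 GiB of scratch memory. For orientation, a hedged sketch of how such a profile maps onto the TensorRT Python API (none of the names below come from this repo):
```python
import tensorrt as trt

# sketch only: how min/opt/max shapes become an optimization profile
builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
config = builder.create_builder_config()
profile = builder.create_optimization_profile()
profile.set_shape(
    'input',                # input tensor name in the ONNX graph
    min=(1, 3, 224, 224),   # smallest shape the engine must accept
    opt=(4, 3, 224, 224),   # shape TensorRT tunes kernels for
    max=(32, 3, 224, 224))  # largest shape the engine must accept
config.add_optimization_profile(profile)
config.max_workspace_size = 1 << 30  # matches the config above
```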

@@ -1,10 +1,3 @@
 from .onnx_helper import add_dummy_nms_for_onnx, dynamic_clip_for_onnx
-from .pytorch2onnx import (build_model_from_cfg,
-                           generate_inputs_and_wrap_model,
-                           preprocess_example_input)

-__all__ = [
-    'build_model_from_cfg', 'generate_inputs_and_wrap_model',
-    'preprocess_example_input', 'add_dummy_nms_for_onnx',
-    'dynamic_clip_for_onnx'
-]
+__all__ = ['add_dummy_nms_for_onnx', 'dynamic_clip_for_onnx']

@@ -1,161 +0,0 @@
from functools import partial
import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint


def generate_inputs_and_wrap_model(config_path,
checkpoint_path,
input_config,
cfg_options=None):
"""Prepare sample input and wrap model for ONNX export.
    The ONNX export API only accepts positional args, and all inputs should
    be torch.Tensor or corresponding types (such as a tuple of tensors).
    So we should call this function before exporting. This function will:
    1. Generate the inputs used to execute the model.
    2. Wrap the model's forward function.
    For example, the forward function of MMDet models has a parameter
    ``return_loss: bool``. We want to set it to False, but the export API
    supports neither bool arguments nor kwargs, so we have to replace the
    forward like ``model.forward = partial(model.forward, return_loss=False)``.
Args:
config_path (str): the OpenMMLab config for the model we want to
export to ONNX
checkpoint_path (str): Path to the corresponding checkpoint
        input_config (dict): the exact data in this dict depends on the
            framework. For MMSeg, we can just declare the input shape
            and generate the dummy data accordingly. However, for MMDet,
            we may need to pass a real image path; otherwise NMS returns
            None because there is no legal bbox.
Returns:
tuple: (model, tensor_data) wrapped model which can be called by \
model(*tensor_data) and a list of inputs which are used to execute \
the model while exporting.
"""
model = build_model_from_cfg(
config_path, checkpoint_path, cfg_options=cfg_options)
one_img, one_meta = preprocess_example_input(input_config)
tensor_data = [one_img]
model.forward = partial(
model.forward, img_metas=[[one_meta]], return_loss=False)
    # PyTorch 1.3 has bugs in the ONNX export of some ops,
    # which we fix by registering extra symbolics for them
    opset_version = 11
    # keep the import inside the function so it does not raise an
    # import error when this function is unused
try:
from mmcv.onnx.symbolic import register_extra_symbolics
except ModuleNotFoundError:
raise NotImplementedError('please update mmcv to version>=v1.0.4')
register_extra_symbolics(opset_version)
    return model, tensor_data


def build_model_from_cfg(config_path, checkpoint_path, cfg_options=None):
"""Build a model from config and load the given checkpoint.
Args:
config_path (str): the OpenMMLab config for the model we want to
export to ONNX
checkpoint_path (str): Path to the corresponding checkpoint
Returns:
torch.nn.Module: the built model
"""
from mmdet.models import build_detector
cfg = mmcv.Config.fromfile(config_path)
if cfg_options is not None:
cfg.merge_from_dict(cfg_options)
# import modules from string list.
if cfg.get('custom_imports', None):
from mmcv.utils import import_modules_from_strings
import_modules_from_strings(**cfg['custom_imports'])
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
cfg.model.pretrained = None
cfg.data.test.test_mode = True
# build the model
cfg.model.train_cfg = None
model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
checkpoint = load_checkpoint(model, checkpoint_path, map_location='cpu')
if 'CLASSES' in checkpoint.get('meta', {}):
model.CLASSES = checkpoint['meta']['CLASSES']
else:
from mmdet.datasets import DATASETS
dataset = DATASETS.get(cfg.data.test['type'])
assert (dataset is not None)
model.CLASSES = dataset.CLASSES
model.cpu().eval()
    return model


def preprocess_example_input(input_config):
"""Prepare an example input image for ``generate_inputs_and_wrap_model``.
Args:
input_config (dict): customized config describing the example input.
Returns:
tuple: (one_img, one_meta), tensor of the example input image and \
meta information for the example input image.
Examples:
>>> from mmdet.core.export import preprocess_example_input
>>> input_config = {
>>> 'input_shape': (1,3,224,224),
>>> 'input_path': 'demo/demo.jpg',
>>> 'normalize_cfg': {
>>> 'mean': (123.675, 116.28, 103.53),
>>> 'std': (58.395, 57.12, 57.375)
>>> }
>>> }
>>> one_img, one_meta = preprocess_example_input(input_config)
>>> print(one_img.shape)
torch.Size([1, 3, 224, 224])
>>> print(one_meta)
{'img_shape': (224, 224, 3),
'ori_shape': (224, 224, 3),
'pad_shape': (224, 224, 3),
'filename': '<demo>.png',
'scale_factor': 1.0,
'flip': False}
"""
input_path = input_config['input_path']
input_shape = input_config['input_shape']
one_img = mmcv.imread(input_path)
one_img = mmcv.imresize(one_img, input_shape[2:][::-1])
show_img = one_img.copy()
if 'normalize_cfg' in input_config.keys():
normalize_cfg = input_config['normalize_cfg']
mean = np.array(normalize_cfg['mean'], dtype=np.float32)
std = np.array(normalize_cfg['std'], dtype=np.float32)
to_rgb = normalize_cfg.get('to_rgb', True)
one_img = mmcv.imnormalize(one_img, mean, std, to_rgb=to_rgb)
one_img = one_img.transpose(2, 0, 1)
one_img = torch.from_numpy(one_img).unsqueeze(0).float().requires_grad_(
True)
(_, C, H, W) = input_shape
one_meta = {
'img_shape': (H, W, C),
'ori_shape': (H, W, C),
'pad_shape': (H, W, C),
'filename': '<demo>.png',
'scale_factor': np.ones(4),
'flip': False,
'show_img': show_img,
}
return one_img, one_meta

@@ -1,5 +1,5 @@
-from .anchor_head import AnchorHead
+from .anchor_head import anchor_head_get_bboxes
 from .fsaf_head import fsaf_head_forward
 from .rpn_head import rpn_head_forward

-__all__ = ['AnchorHead', 'rpn_head_forward', 'fsaf_head_forward']
+__all__ = ['anchor_head_get_bboxes', 'rpn_head_forward', 'fsaf_head_forward']

@@ -1,127 +1,110 @@
 import torch
-import torch.nn as nn

 import mmdeploy
-from mmdeploy.utils import MODULE_REWRITERS, is_dynamic_shape
+from mmdeploy.utils import FUNCTION_REWRITERS, is_dynamic_shape


-@MODULE_REWRITERS.register_rewrite_module(module_type='mmdet.models.AnchorHead'
-                                          )
-@MODULE_REWRITERS.register_rewrite_module(module_type='mmdet.models.RetinaHead'
-                                          )
-class AnchorHead(nn.Module):
-
-    def __init__(self, module, cfg, **kwargs):
-        super(AnchorHead, self).__init__()
-        self.module = module
-        self.anchor_generator = self.module.anchor_generator
-        self.bbox_coder = self.module.bbox_coder
-
-        self.test_cfg = module.test_cfg
-        self.num_classes = module.num_classes
-        self.use_sigmoid_cls = module.use_sigmoid_cls
-        self.cls_out_channels = module.cls_out_channels
-
-        self.deploy_cfg = cfg
-
-    def forward(self, *args, **kwargs):
-        return self.module(*args, **kwargs)
-
-    def get_bboxes(self,
-                   cls_scores,
-                   bbox_preds,
-                   img_shape,
-                   with_nms=True,
-                   cfg=None,
-                   **kwargs):
-        assert len(cls_scores) == len(bbox_preds)
-        deploy_cfg = self.deploy_cfg
-        num_levels = len(cls_scores)
-
-        device = cls_scores[0].device
-        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
-        mlvl_anchors = self.anchor_generator.grid_anchors(
-            featmap_sizes, device=device)
-
-        mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
-        mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
-
-        cfg = self.test_cfg if cfg is None else cfg
-        assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) == len(
-            mlvl_anchors)
-        batch_size = mlvl_cls_scores[0].shape[0]
-        nms_pre = cfg.get('nms_pre', -1)
-
-        # loop over features, decode boxes
-        mlvl_bboxes = []
-        mlvl_scores = []
-        for level_id, cls_score, bbox_pred, anchors in zip(
-                range(num_levels), mlvl_cls_scores, mlvl_bbox_preds,
-                mlvl_anchors):
-            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
-            cls_score = cls_score.permute(0, 2, 3,
-                                          1).reshape(batch_size, -1,
-                                                     self.cls_out_channels)
-            if self.use_sigmoid_cls:
-                scores = cls_score.sigmoid()
-            else:
-                scores = cls_score.softmax(-1)
-            bbox_pred = bbox_pred.permute(0, 2, 3,
-                                          1).reshape(batch_size, -1, 4)
-
-            # use static anchor if input shape is static
-            if not is_dynamic_shape(deploy_cfg):
-                anchors = anchors.data
-
-            anchors = anchors.expand_as(bbox_pred)
-
-            enable_nms_pre = True
-            backend = deploy_cfg['backend']
-            # topk in tensorrt does not support shape<k
-            # final level might meet the problem
-            if backend == 'tensorrt':
-                enable_nms_pre = (level_id != num_levels - 1)
-
-            if nms_pre > 0 and enable_nms_pre:
-                # Get maximum scores for foreground classes.
-                if self.use_sigmoid_cls:
-                    max_scores, _ = scores.max(-1)
-                else:
-                    # remind that we set FG labels to [0, num_class-1]
-                    # since mmdet v2.0
-                    # BG cat_id: num_class
-                    max_scores, _ = scores[..., :-1].max(-1)
-                _, topk_inds = max_scores.topk(nms_pre)
-                batch_inds = torch.arange(batch_size).view(
-                    -1, 1).expand_as(topk_inds)
-                anchors = anchors[batch_inds, topk_inds, :]
-                bbox_pred = bbox_pred[batch_inds, topk_inds, :]
-                scores = scores[batch_inds, topk_inds, :]
-
-            bboxes = self.bbox_coder.decode(
-                anchors, bbox_pred, max_shape=img_shape)
-            mlvl_bboxes.append(bboxes)
-            mlvl_scores.append(scores)
-
-        batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
-        batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
-        # ignore background class
-        if not self.use_sigmoid_cls:
-            batch_mlvl_scores = batch_mlvl_scores[..., :self.num_classes]
-        if not with_nms:
-            return batch_mlvl_bboxes, batch_mlvl_scores
-
-        max_output_boxes_per_class = cfg.nms.get('max_output_boxes_per_class',
-                                                 200)
-        iou_threshold = cfg.nms.get('iou_threshold', 0.5)
-        score_threshold = cfg.score_thr
-        nms_pre = cfg.get('deploy_nms_pre', -1)
-        return mmdeploy.mmdet.core.export.add_dummy_nms_for_onnx(
-            batch_mlvl_bboxes,
-            batch_mlvl_scores,
-            max_output_boxes_per_class,
-            iou_threshold=iou_threshold,
-            score_threshold=score_threshold,
-            pre_top_k=nms_pre,
-            after_top_k=cfg.max_per_img)
+@FUNCTION_REWRITERS.register_rewriter(
+    func_name='mmdet.models.AnchorHead.get_bboxes')
+@FUNCTION_REWRITERS.register_rewriter(
+    func_name='mmdet.models.RetinaHead.get_bboxes')
+def anchor_head_get_bboxes(rewriter,
+                           self,
+                           cls_scores,
+                           bbox_preds,
+                           img_shape,
+                           with_nms=True,
+                           cfg=None,
+                           **kwargs):
+    assert len(cls_scores) == len(bbox_preds)
+    deploy_cfg = rewriter.cfg
+    num_levels = len(cls_scores)
+
+    device = cls_scores[0].device
+    featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
+    mlvl_anchors = self.anchor_generator.grid_anchors(
+        featmap_sizes, device=device)
+
+    mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
+    mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
+
+    cfg = self.test_cfg if cfg is None else cfg
+    assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) == len(mlvl_anchors)
+    batch_size = mlvl_cls_scores[0].shape[0]
+    nms_pre = cfg.get('nms_pre', -1)
+
+    # loop over features, decode boxes
+    mlvl_bboxes = []
+    mlvl_scores = []
+    for level_id, cls_score, bbox_pred, anchors in zip(
+            range(num_levels), mlvl_cls_scores, mlvl_bbox_preds, mlvl_anchors):
+        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+        cls_score = cls_score.permute(0, 2, 3,
+                                      1).reshape(batch_size, -1,
+                                                 self.cls_out_channels)
+        if self.use_sigmoid_cls:
+            scores = cls_score.sigmoid()
+        else:
+            scores = cls_score.softmax(-1)
+        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
+
+        # use static anchor if input shape is static
+        if not is_dynamic_shape(deploy_cfg):
+            anchors = anchors.data
+
+        anchors = anchors.expand_as(bbox_pred)
+
+        enable_nms_pre = True
+        backend = deploy_cfg['backend']
+        # topk in tensorrt does not support shape<k
+        # final level might meet the problem
+        if backend == 'tensorrt':
+            enable_nms_pre = (level_id != num_levels - 1)
+
+        if nms_pre > 0 and enable_nms_pre:
+            # Get maximum scores for foreground classes.
+            if self.use_sigmoid_cls:
+                max_scores, _ = scores.max(-1)
+            else:
+                # remind that we set FG labels to [0, num_class-1]
+                # since mmdet v2.0
+                # BG cat_id: num_class
+                max_scores, _ = scores[..., :-1].max(-1)
+            _, topk_inds = max_scores.topk(nms_pre)
+            batch_inds = torch.arange(batch_size).view(-1,
+                                                       1).expand_as(topk_inds)
+            anchors = anchors[batch_inds, topk_inds, :]
+            bbox_pred = bbox_pred[batch_inds, topk_inds, :]
+            scores = scores[batch_inds, topk_inds, :]
+
+        if not is_dynamic_shape(deploy_cfg):
+            img_shape = [int(val) for val in img_shape]
+        bboxes = self.bbox_coder.decode(
+            anchors, bbox_pred, max_shape=img_shape)
+        mlvl_bboxes.append(bboxes)
+        mlvl_scores.append(scores)

+    batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
+    batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
+    # ignore background class
+    if not self.use_sigmoid_cls:
+        batch_mlvl_scores = batch_mlvl_scores[..., :self.num_classes]
+    if not with_nms:
+        return batch_mlvl_bboxes, batch_mlvl_scores
+
+    max_output_boxes_per_class = cfg.nms.get('max_output_boxes_per_class', 200)
+    iou_threshold = cfg.nms.get('iou_threshold', 0.5)
+    score_threshold = cfg.score_thr
+    nms_pre = cfg.get('deploy_nms_pre', -1)
+    return mmdeploy.mmdet.core.export.add_dummy_nms_for_onnx(
+        batch_mlvl_bboxes,
+        batch_mlvl_scores,
+        max_output_boxes_per_class,
+        iou_threshold=iou_threshold,
+        score_threshold=score_threshold,
+        pre_top_k=nms_pre,
+        after_top_k=cfg.max_per_img)
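As the diff shows, a function rewriter receives a `rewriter` context as its first argument, carrying the deploy config (`rewriter.cfg`) and the original implementation (`rewriter.origin_func`, used in single_stage.py below). A minimal toy sketch of the same registration pattern; the target function here is purely illustrative and is not something this commit rewrites:
```python
from mmdeploy.utils import FUNCTION_REWRITERS


@FUNCTION_REWRITERS.register_rewriter(
    func_name='torchvision.ops.nms')  # illustrative target only
def nms_rewrite(rewriter, boxes, scores, iou_threshold):
    # the deploy config chosen at export time is available here
    if rewriter.cfg['backend'] == 'tensorrt':
        pass  # a backend-specific path would go here
    # fall back to the original implementation
    return rewriter.origin_func(boxes, scores, iou_threshold)
```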

@@ -1,4 +1,4 @@
-from .single_stage import SingleStageDetector
+from .single_stage import single_stage_forward
 from .two_stage import extract_feat

-__all__ = ['SingleStageDetector', 'extract_feat']
+__all__ = ['single_stage_forward', 'extract_feat']

@@ -1,22 +1,22 @@
 import torch
-import torch.nn as nn

-from mmdeploy.utils import MODULE_REWRITERS
+from mmdeploy.utils import FUNCTION_REWRITERS, mark


-@MODULE_REWRITERS.register_rewrite_module(module_type='mmdet.models.RetinaNet')
-@MODULE_REWRITERS.register_rewrite_module(
-    module_type='mmdet.models.SingleStageDetector')
-class SingleStageDetector(nn.Module):
-
-    def __init__(self, module, cfg, **kwargs):
-        super(SingleStageDetector, self).__init__()
-        self.module = module
-        self.bbox_head = module.bbox_head
-
-    def forward(self, data, **kwargs):
-        # get origin input shape to support onnx dynamic shape
-        img_shape = torch._shape_as_tensor(data)[2:]
-        x = self.module.extract_feat(data)
-        outs = self.bbox_head(x)
-        return self.bbox_head.get_bboxes(*outs, img_shape, **kwargs)
+@FUNCTION_REWRITERS.register_rewriter(
+    'mmdet.models.SingleStageDetector.extract_feat')
+@mark('extract_feat')
+def single_stage_extract_feat(rewriter, self, img):
+    return rewriter.origin_func(self, img)


+@FUNCTION_REWRITERS.register_rewriter(
+    func_name='mmdet.models.RetinaNet.forward')
+@FUNCTION_REWRITERS.register_rewriter(
+    func_name='mmdet.models.SingleStageDetector.forward')
+def single_stage_forward(rewriter, self, data, **kwargs):
+    # get origin input shape to support onnx dynamic shape
+    img_shape = torch._shape_as_tensor(data)[2:]
+    x = self.extract_feat(data)
+    outs = self.bbox_head(x)
+    return self.bbox_head.get_bboxes(*outs, img_shape, **kwargs)
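The `@mark('extract_feat')` decorator tags the backbone features in the traced graph, presumably so the exported model can later be partitioned at that boundary (the "Prepare split export for TensorRT" item above). Separately, `single_stage_forward` reads the image size with `torch._shape_as_tensor` rather than `tensor.shape`; a standalone sketch of why, using a toy module that is not part of this repo:
```python
import torch


class ToyModule(torch.nn.Module):

    def forward(self, img):
        # tensor.shape is traced as constant Python ints, freezing the
        # input size into the ONNX graph; _shape_as_tensor keeps H and W
        # as tensors, so the graph follows the real input shape at runtime
        return torch._shape_as_tensor(img)[2:]


# hypothetical export call demonstrating the dynamic-shape effect
torch.onnx.export(
    ToyModule(),
    torch.rand(1, 3, 224, 224),
    'toy.onnx',
    opset_version=11,
    input_names=['input'],
    dynamic_axes={'input': {2: 'height', 3: 'width'}})
```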