[Feature] add yolox ncnn (#29)

* add yolox ncnn

* add ncnn android performance of yolox

* add ut

* fix lint

* fix None bugs for ncnn

* test codecov

* test codecov

* add device

* fix yapf

* remove if-else for img shape

* use channelshuffle optimize

* change benchmark after channelshuffle

* fix yapf

* fix yapf

* fuse continuous reshape

* fix static shape deploy

* fix code

* drop pad

* only static shape

* fix static

* fix docstring
hanrui1sensetime 2022-01-19 13:54:45 +08:00 committed by GitHub
parent 997d111a6f
commit e6e32a9db4
13 changed files with 321 additions and 10 deletions

@@ -0,0 +1,4 @@
_base_ = ['../_base_/base_static.py', '../../_base_/backends/ncnn.py']
codebase_config = dict(model_type='ncnn_end2end')
onnx_config = dict(output_names=['detection_output'], input_shape=[416, 416])
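
A quick way to sanity-check the merged deploy config is to load it with mmcv; `Config.fromfile` resolves the `_base_` inheritance shown above. A minimal sketch — the file path is an assumption, since the diff viewer does not show the new file's name:

```python
from mmcv import Config

# Hypothetical path for the new config added above.
deploy_cfg = Config.fromfile(
    'configs/mmdet/detection/single-stage_ncnn_static-416x416.py')
print(deploy_cfg.codebase_config.model_type)  # 'ncnn_end2end'
print(deploy_cfg.onnx_config.input_shape)     # [416, 416]
```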

@@ -319,6 +319,16 @@ Users can directly test the speed through [how_to_measure_performance_of_models.
<td align="center">15.11</td>
<td>$MMDET_DIR/configs/ssd/ssdlite_mobilenetv2_scratch_600e_coco.py</td>
</tr>
<tr>
<td align="center">YOLOX</td>
<td align="center">COCO</td>
<td align="center">1x3x416x416</td>
<td align="center">111.60</td>
<td align="center">8.96</td>
<td align="center">134.50</td>
<td align="center">7.43</td>
<td>$MMDET_DIR/configs/yolox/yolox_tiny_8x8_300e_coco.py</td>
</tr>
</tbody>
</table>
</div>

@@ -18,7 +18,7 @@ Please refer to [get_started.md](https://github.com/open-mmlab/mmdetection/blob/
| SSD | ObjectDetection | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) |
| VFNet | ObjectDetection | N | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/vfnet) |
| YOLOv3 | ObjectDetection | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolo) |
| YOLOX | ObjectDetection | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox) |
| YOLOX | ObjectDetection | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox) |
| Cascade R-CNN | ObjectDetection | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
| Faster R-CNN | ObjectDetection | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) |
| Faster R-CNN + DCN | ObjectDetection | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) |

@@ -7,7 +7,7 @@ The table below lists the models that are guaranteed to be exportable to other b
| RetinaNet | MMDetection | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet) |
| Faster R-CNN | MMDetection | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) |
| YOLOv3 | MMDetection | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolo) |
| YOLOX | MMDetection | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox) |
| YOLOX | MMDetection | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox) |
| FCOS | MMDetection | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fcos) |
| FSAF | MMDetection | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fsaf) |
| Mask R-CNN | MMDetection | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) |
@@ -44,3 +44,4 @@ The table below lists the models that are guaranteed to be exportable to other b
- Tag:
- static: This model only supports static export. Please use a `static` deploy config, such as $MMDEPLOY_DIR/configs/mmseg/segmentation_tensorrt_static-1024x2048.py.
- SSD: When converting an SSD model, use a deploy config with a min shape of 300x300-512x512 rather than 320x320-1344x1344, for example $MMDEPLOY_DIR/configs/mmdet/detection/detection_tensorrt_dynamic-300x300-512x512.py.
- YOLOX: YOLOX with ncnn only supports static shape.
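
A static-only constraint like this can also be checked programmatically with mmdeploy's `is_dynamic_shape` helper (used elsewhere in this PR). A minimal sketch, reusing the hypothetical config path from above:

```python
from mmcv import Config

from mmdeploy.utils.config_utils import is_dynamic_shape

# Hypothetical path; a YOLOX + ncnn deploy config must be static.
deploy_cfg = Config.fromfile(
    'configs/mmdet/detection/single-stage_ncnn_static-416x416.py')
assert not is_dynamic_shape(deploy_cfg)
```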

@@ -3,12 +3,13 @@
### Backends
CPU: ncnn, ONNXRuntime, OpenVINO
GPU: TensorRT, PPLNN
GPU: ncnn, TensorRT, PPLNN
### Latency benchmark
#### Platform
- Ubuntu 18.04 operating system
- ncnn 20211208
- Cuda 11.3
- TensorRT 7.2.3.4
- Docker 20.10.8
@@ -19,7 +20,7 @@ GPU: TensorRT, PPLNN
- Batch size is 1
- Synchronize after each inference
- For the latency benchmark, we report the average latency over 100 images from each dataset.
- Warm-up. For classification tasks we warm up 1010 iterations; for other tasks, 10 iterations.
- Warm-up. For the ncnn backend we warm up 30 iterations; for the other backends, 1010 iterations for classification tasks and 10 iterations for other tasks.
- Input resolution varies with each codebase's dataset; except for `mmediting`, all codebases use real images as input.
@@ -319,6 +320,16 @@ GPU: TensorRT, PPLNN
<td align="center">15.11</td>
<td>$MMDET_DIR/configs/ssd/ssdlite_mobilenetv2_scratch_600e_coco.py</td>
</tr>
<tr>
<td align="center">YOLOX</td>
<td align="center">COCO</td>
<td align="center">1x3x416x416</td>
<td align="center">111.60</td>
<td align="center">8.96</td>
<td align="center">134.50</td>
<td align="center">7.43</td>
<td>$MMDET_DIR/configs/yolox/yolox_tiny_8x8_300e_coco.py</td>
</tr>
</tbody>
</table>
</div>

@@ -109,14 +109,17 @@ class NCNNWrapper(BaseWrapper):
mat = result[name]
# deal with special case
if mat.empty():
mat = None
logger.warning(
f'The "{name}" output of ncnn model is empty.')
continue
outputs[name][batch_id] = torch.from_numpy(np.array(mat))
# stack outputs together
for name, output_tensor in outputs.items():
outputs[name] = torch.stack(output_tensor)
if None in output_tensor:
outputs[name] = None
else:
outputs[name] = torch.stack(output_tensor)
return outputs
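
The effect of this hunk: an empty ncnn `Mat` (e.g. zero detections for one image) no longer reaches `torch.from_numpy`; its slot simply stays `None`, and one empty slot collapses the whole named output to `None` instead of producing a ragged stack. A standalone sketch of the collapse logic:

```python
import torch

# One image in the batch produced an empty Mat, so its slot is still None.
output_tensor = [torch.rand(10, 6), None]

# Mirrors the hunk above: torch.stack cannot mix tensors and missing
# entries, so any None slot makes the whole named output None.
stacked = None if None in output_tensor else torch.stack(output_tensor)
assert stacked is None
```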

@@ -8,7 +8,7 @@ from mmcv.parallel import DataContainer
from torch.utils.data import Dataset
from mmdeploy.utils import Task, get_root_logger
from mmdeploy.utils.config_utils import get_input_shape
from mmdeploy.utils.config_utils import get_input_shape, is_dynamic_shape
from ...base import BaseTask
from .mmdetection import MMDET_TASK
@@ -114,7 +114,19 @@ class ObjectDetection(BaseTask):
from mmcv.parallel import collate, scatter
if not isinstance(imgs, (list, tuple)):
imgs = [imgs]
dynamic_flag = is_dynamic_shape(self.deploy_cfg)
cfg = process_model_config(self.model_cfg, imgs, input_shape)
# Drop pad_to_square when the shape is static, because a static
# graph requires the input shape to be fixed before the image is fed in.
if not dynamic_flag:
transform = cfg.data.test.pipeline[1]
if 'transforms' in transform:
transform_list = transform['transforms']
for i, step in enumerate(transform_list):
if step['type'] == 'Pad' and 'pad_to_square' in step \
and step['pad_to_square']:
transform_list.pop(i)
break
test_pipeline = Compose(cfg.data.test.pipeline)
data_list = []
for img in imgs:
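
For illustration, a self-contained sketch of the `Pad` removal on a hypothetical YOLOX-style test pipeline (the pipeline dict is an assumption modeled on mmdet conventions, not copied from this PR):

```python
# Hypothetical mmdet-style test pipeline; only its structure matters here.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(416, 416),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='Pad', pad_to_square=True),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ]),
]

# Same logic as the hunk above: for a static shape, drop the
# data-dependent square padding.
transform = test_pipeline[1]
if 'transforms' in transform:
    transform_list = transform['transforms']
    for i, step in enumerate(transform_list):
        if step['type'] == 'Pad' and step.get('pad_to_square'):
            transform_list.pop(i)
            break

assert all(step['type'] != 'Pad' for step in transform['transforms'])
```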

@@ -578,7 +578,7 @@ class NCNNEnd2EndModel(End2EndModel):
outputs = self.wrapper({self.input_name: imgs})
for key, item in outputs.items():
if item is None:
return [np.zeros((1, 0, 6))]
return [np.zeros((1, 0, 5)), np.zeros((1, 0))]
out = self.wrapper.output_to_list(outputs)[0]
labels = out[:, :, 0] - 1
scales = torch.tensor([W, H, W, H]).reshape(1, 1, 4)
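
For reference, each record from ncnn's DetectionOutput layer is `[label, score, x1, y1, x2, y2]` with box coordinates normalized to [0, 1], which is why the hunk above rescales by `[W, H, W, H]` and shifts labels by -1 (class 0 is background). A toy decode under those assumptions:

```python
import torch

H, W = 416, 416
# One fake detection record: [label, score, x1, y1, x2, y2].
out = torch.tensor([[[1.0, 0.9, 0.10, 0.20, 0.50, 0.60]]])

labels = out[:, :, 0] - 1                      # undo the background offset
scores = out[:, :, 1]
scales = torch.tensor([W, H, W, H]).reshape(1, 1, 4)
bboxes = out[:, :, 2:6] * scales               # back to pixel coordinates
```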

@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backbones import * # noqa: F401, F403
from .dense_heads import * # noqa: F401,F403
from .detectors import * # noqa: F401,F403
from .necks import * # noqa: F401,F403

@@ -0,0 +1,45 @@
import torch
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.csp_darknet.Focus.forward',
backend='ncnn')
def focus__forward__ncnn(ctx, self, x):
"""Rewrite forward function of Focus class for ncnn.
Focus width and height information into channel space. NCNN does not
support slice operator which step greater than 1, so we use another
way to implement.
Args:
x (Tensor): The input tensor with shape (N, C, H, W).
Returns:
x (Tensor): The calculated tensor with shape (N, 4*C, H//2, W//2).
"""
batch_size, c, h, w = x.shape
assert h % 2 == 0 and w % 2 == 0, f'focus for yolox needs even feature\
height and width, got {(h, w)}.'
x = x.reshape(batch_size, c * h, 1, w)
_b, _c, _h, _w = x.shape
g = _c // 2
# fuse to ncnn's shufflechannel
x = x.view(_b, g, 2, _h, _w)
x = torch.transpose(x, 1, 2).contiguous()
x = x.view(_b, -1, _h, _w)
x = x.reshape(_b, c * h * w, 1, 1)
_b, _c, _h, _w = x.shape
g = _c // 2
# fuse to ncnn's shufflechannel
x = x.view(_b, g, 2, _h, _w)
x = torch.transpose(x, 1, 2).contiguous()
x = x.view(_b, -1, _h, _w)
x = x.reshape(_b, c * 4, h // 2, w // 2)
return self.conv(x)
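
The two reshape-plus-shuffle passes reproduce exactly the 2x2 strided-slice gather of the original Focus (top-left, bottom-left, top-right, bottom-right sub-images stacked along channels). A quick self-contained equivalence check, with the trailing conv omitted:

```python
import torch


def focus_reference(x):
    """Original Focus gather: four 2x2-strided sub-images along channels."""
    return torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2],
                      x[..., ::2, 1::2], x[..., 1::2, 1::2]), dim=1)


def focus_shuffle(x):
    """The rewrite above, minus self.conv."""
    b, c, h, w = x.shape
    x = x.reshape(b, c * h, 1, w)
    g = x.shape[1] // 2  # first shuffle: separate even and odd rows
    x = x.view(b, g, 2, 1, w).transpose(1, 2).contiguous().view(b, -1, 1, w)
    x = x.reshape(b, c * h * w, 1, 1)
    g = x.shape[1] // 2  # second shuffle: separate even and odd columns
    x = x.view(b, g, 2, 1, 1).transpose(1, 2).contiguous().view(b, -1, 1, 1)
    return x.reshape(b, c * 4, h // 2, w // 2)


x = torch.rand(1, 3, 8, 8)
assert torch.equal(focus_reference(x), focus_shuffle(x))
```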

@@ -5,12 +5,12 @@ from .fovea_head import fovea_head__get_bboxes
from .rpn_head import rpn_head__get_bboxes, rpn_head__get_bboxes__ncnn
from .ssd_head import ssd_head__get_bboxes__ncnn
from .yolo_head import yolov3_head__get_bboxes, yolov3_head__get_bboxes__ncnn
from .yolox_head import yolox_head__get_bboxes
from .yolox_head import yolox_head__get_bboxes, yolox_head__get_bboxes__ncnn
__all__ = [
'rpn_head__get_bboxes', 'rpn_head__get_bboxes__ncnn',
'yolov3_head__get_bboxes', 'yolov3_head__get_bboxes__ncnn',
'yolox_head__get_bboxes', 'base_dense_head__get_bbox',
'fovea_head__get_bboxes', 'base_dense_head__get_bboxes__ncnn',
'ssd_head__get_bboxes__ncnn'
'ssd_head__get_bboxes__ncnn', 'yolox_head__get_bboxes__ncnn'
]

@@ -91,3 +91,127 @@ def yolox_head__get_bboxes(ctx,
return multiclass_nms(bboxes, scores, max_output_boxes_per_class,
iou_threshold, score_threshold, pre_top_k,
keep_top_k)
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.YOLOXHead.get_bboxes', backend='ncnn')
def yolox_head__get_bboxes__ncnn(ctx,
self,
cls_scores,
bbox_preds,
objectnesses,
img_metas=None,
cfg=None,
rescale=False,
with_nms=True):
"""Rewrite `get_bboxes` of YOLOXHead for ncnn backend.
1. Decode the priors into box format so that the ncnn DetectionOutput
layer can do the post-processing.
2. The batch dimension is supported by PyTorch but not by ncnn, so
negative axis values in torch.cat are rewritten as the corresponding
positive values to avoid an axis shift.
3. ncnn's `BinaryOps` operator does not broadcast 2-dimensional tensors,
so this function unsqueezes them to 3 dimensions for a correct
`BinaryOps` calculation.
Args:
ctx: Context that contains original meta information.
self: Represent the instance of the original class.
cls_scores (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
objectnesses (list[Tensor], Optional): Score factor for
all scale level, each is a 4D-tensor, has shape
(batch_size, 1, H, W).
img_metas (list[dict]): Image meta info. Default None.
cfg (mmcv.Config, Optional): Test / postprocessing configuration,
if None, test_cfg would be used. Default None.
rescale (bool): If True, return boxes in original image space.
Default False.
with_nms (bool): If True, do nms before return boxes.
Default True.
Returns:
output__ncnn (Tensor): outputs, shape is [N, num_det, 6].
"""
from mmdeploy.codebase.mmdet.core.ops import ncnn_detection_output_forward
from mmdeploy.utils.config_utils import is_dynamic_shape
from mmdeploy.utils import get_root_logger
dynamic_flag = is_dynamic_shape(ctx.cfg)
if dynamic_flag:
logger = get_root_logger()
logger.warning('YOLOX does not support dynamic shape with ncnn.')
img_height = int(img_metas[0]['img_shape'][0])
img_width = int(img_metas[0]['img_shape'][1])
assert len(cls_scores) == len(bbox_preds) == len(objectnesses)
device = cls_scores[0].device
cfg = self.test_cfg if cfg is None else cfg
batch_size = bbox_preds[0].shape[0]
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes, device=device, with_stride=True)
mlvl_priors = [mlvl_prior.unsqueeze(0) for mlvl_prior in mlvl_priors]
flatten_priors = torch.cat(mlvl_priors, dim=1)
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1,
self.cls_out_channels)
for cls_score in cls_scores
]
flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
for bbox_pred in bbox_preds
]
flatten_objectness = [
objectness.permute(0, 2, 3, 1).reshape(batch_size, -1, 1)
for objectness in objectnesses
]
cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
dummy_cls_scores = torch.zeros(
batch_size, cls_scores.shape[-2], 1, device=cls_scores.device)
batch_mlvl_scores = torch.cat([dummy_cls_scores, cls_scores], dim=2)
score_factor = torch.cat(flatten_objectness, dim=1).sigmoid()
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
assert flatten_priors.shape[-1] == 4, f'yolox needs (B, N, 4) priors, got\
(B, N, {flatten_priors.shape[-1]})'
prior_box_x1 = (flatten_priors[:, :, 0:1] - flatten_priors[:, :, 2:3] / 2)\
/ img_width
prior_box_y1 = (flatten_priors[:, :, 1:2] - flatten_priors[:, :, 3:4] / 2)\
/ img_height
prior_box_x2 = (flatten_priors[:, :, 0:1] + flatten_priors[:, :, 2:3] / 2)\
/ img_width
prior_box_y2 = (flatten_priors[:, :, 1:2] + flatten_priors[:, :, 3:4] / 2)\
/ img_height
prior_box_ncnn = torch.cat(
[prior_box_x1, prior_box_y1, prior_box_x2, prior_box_y2], dim=2)
scores = batch_mlvl_scores.permute(0, 2, 1).unsqueeze(3) * \
score_factor.permute(0, 2, 1).unsqueeze(3)
scores = scores.squeeze(3).permute(0, 2, 1)
batch_mlvl_bboxes = flatten_bbox_preds.reshape(batch_size, 1, -1)
batch_mlvl_scores = scores.reshape(batch_size, 1, -1)
batch_mlvl_priors = prior_box_ncnn.reshape(batch_size, 1, -1)
batch_mlvl_vars = torch.ones_like(batch_mlvl_priors)
batch_mlvl_priors = torch.cat([batch_mlvl_priors, batch_mlvl_vars], dim=1)
deploy_cfg = ctx.cfg
post_params = get_post_processing_params(deploy_cfg)
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.pre_top_k
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
vars = torch.tensor([1, 1, 1, 1], dtype=torch.float32)
output__ncnn = ncnn_detection_output_forward(
batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_priors,
score_threshold, iou_threshold, pre_top_k, keep_top_k,
self.num_classes + 1,
vars.cpu().detach().numpy())
return output__ncnn
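
One subtlety: DetectionOutput reserves class id 0 for background, which is why a zero score column is prepended (`dummy_cls_scores`), `self.num_classes + 1` is passed, and `NCNNEnd2EndModel` later subtracts 1 from the decoded labels. A toy sketch of the shift:

```python
import torch

batch_size, num_priors, num_classes = 1, 8, 80
cls_scores = torch.rand(batch_size, num_priors, num_classes)

# Prepend a zero "background" column: real classes move to ids 1..80,
# and consumers shift decoded labels back with `label - 1`.
dummy = torch.zeros(batch_size, num_priors, 1)
scores_with_bg = torch.cat([dummy, cls_scores], dim=2)
assert scores_with_bg.shape[-1] == num_classes + 1
```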

@@ -112,6 +112,15 @@ def get_fcos_head_model():
return model
def get_focus_backbone_model():
"""Backbone Focus Config."""
from mmdet.models.backbones.csp_darknet import Focus
model = Focus(3, 32)
model.requires_grad_(False)
return model
def get_l2norm_forward_model():
"""L2Norm Neck Config."""
from mmdet.models.necks.ssd_neck import L2Norm
@@ -148,6 +157,35 @@ def get_single_roi_extractor():
return model
def test_focus_forward_ncnn():
backend_type = Backend.NCNN
check_backend(backend_type)
focus_model = get_focus_backbone_model()
focus_model.cpu().eval()
s = 128
seed_everything(1234)
x = torch.rand(1, 3, s, s)
model_outputs = [focus_model.forward(x)]
wrapped_model = WrapModel(focus_model, 'forward')
rewrite_inputs = {
'x': x,
}
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type.value),
onnx_config=dict(input_shape=None)))
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
for model_output, rewrite_output in zip(model_outputs[0],
rewrite_outputs[0]):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
assert np.allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_l2norm_forward(backend_type):
check_backend(backend_type)
@@ -930,6 +968,68 @@ def test_yolox_head_get_bboxes(backend_type: Backend):
assert rewrite_outputs is not None
def test_yolox_head_get_bboxes_ncnn():
"""Test get_bboxes rewrite of yolox head for ncnn."""
backend_type = Backend.NCNN
check_backend(backend_type)
yolox_head = get_yolox_head_model()
yolox_head.cpu().eval()
s = 128
img_metas = [{
'scale_factor': np.ones(4),
'pad_shape': (s, s, 3),
'img_shape': (s, s, 3)
}]
output_names = ['detection_output']
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type.value),
onnx_config=dict(output_names=output_names, input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=20,
pre_top_k=5000,
keep_top_k=10,
background_label_id=0,
))))
seed_everything(1234)
cls_scores = [
torch.rand(1, yolox_head.num_classes, pow(2, i), pow(2, i))
for i in range(3, 0, -1)
]
seed_everything(5678)
bbox_preds = [
torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(3, 0, -1)
]
seed_everything(9101)
objectnesses = [
torch.rand(1, 1, pow(2, i), pow(2, i)) for i in range(3, 0, -1)
]
# to get outputs of onnx model after rewrite
wrapped_model = WrapModel(yolox_head, 'get_bboxes', img_metas=img_metas)
rewrite_inputs = {
'cls_scores': cls_scores,
'bbox_preds': bbox_preds,
'objectnesses': objectnesses,
}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
# output should be of shape [1, N, 6]
if is_backend_output:
assert rewrite_outputs[0].shape[-1] == 6
else:
assert rewrite_outputs.shape[-1] == 6
def get_vfnet_head_model():
"""VFNet Head Config."""
test_cfg = mmcv.Config(