[Feature]: Support yolov3 (#167)

* support yolov3 with ort and trt

* add ncnn compare

* fix yolo_head ncnn rewriter

* align performance with ort trt for yolov3

* update doc

* add test for compare with equal, less, greater

* change namespace

* reformat cpp

* fix lint

* fix lint

* add unit test for yolov3 head

* remove compare op

* update doc

* update table format in docs

* update comments

* update

Co-authored-by: hanrui1sensetime <hanrui1@sensetime.com>
RunningLeon 2021-11-08 16:07:58 +08:00 committed by GitHub
parent 2ffc657665
commit cd51f12f32
9 changed files with 469 additions and 48 deletions

@@ -3403,8 +3403,8 @@ int main(int argc, char** argv) {
         fprintf(pp, "%-16s", "LayerNorm");
     } else if (op == "LeakyRelu") {
         fprintf(pp, "%-16s", "ReLU");
-    } else if (op == "Less") {
-        fprintf(pp, "%-16s", "Compare");
+    } else if (op == "Threshold") {
+        fprintf(pp, "%-16s", "Threshold");
     } else if (op == "Log") {
         fprintf(pp, "%-16s", "UnaryOp");
     } else if (op == "LRN") {
@@ -4315,11 +4315,10 @@ int main(int argc, char** argv) {
         }
     } else if (op == "LeakyRelu") {
         float alpha = get_node_attr_f(node, "alpha", 0.01f);
         fprintf(pp, " 0=%e", alpha);
-    } else if (op == "Less") {
-        int op_type = 1;
-        fprintf(pp, " 0=%d", op_type);
+    } else if (op == "Threshold") {
+        float threshold = get_node_attr_f(node, "threshold", 0.f);
+        fprintf(pp, " 0=%e", threshold);
     } else if (op == "Log") {
         int op_type = 8;
         fprintf(pp, " 0=%d", op_type);
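
The new `Threshold` branch writes two fragments for the layer: the layer type padded to 16 columns, and the `threshold` attribute as parameter `0` in `%e` notation. Python's `%` operator follows the same printf conventions, so the fragments can be previewed with the sketch below. `format_threshold_param` is a hypothetical helper for illustration only; it is not part of the converter, and a real param line contains more fields than these two fragments.

```python
def format_layer_type(layer_type: str) -> str:
    """Mimic fprintf(pp, "%-16s", ...): left-justify in a 16-column field."""
    return "%-16s" % layer_type


def format_threshold_param(threshold: float = 0.0) -> str:
    """Mimic fprintf(pp, " 0=%e", threshold): parameter 0 in exponent form.

    The default of 0.0 mirrors get_node_attr_f(node, "threshold", 0.f).
    """
    return " 0=%e" % threshold


print(repr(format_layer_type("Threshold")))
print(repr(format_threshold_param(0.005)))
```
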

@@ -6,6 +6,7 @@ codebase_config = dict(
     task='ObjectDetection',
     post_processing=dict(
         score_threshold=0.05,
+        confidence_threshold=0.005,  # for YOLOv3
         iou_threshold=0.5,
         max_output_boxes_per_class=200,
         pre_top_k=-1,

@@ -0,0 +1,12 @@
_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/tensorrt.py']

backend_config = dict(
    common_config=dict(max_workspace_size=1 << 30),
    model_inputs=[
        dict(
            input_shapes=dict(
                input=dict(
                    min_shape=[1, 3, 160, 160],
                    opt_shape=[1, 3, 608, 608],
                    max_shape=[1, 3, 608, 608])))
    ])
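
The config above declares a dynamic input range from 160x160 up to 608x608, with 608x608 as the optimization target, and a workspace of `1 << 30` bytes (1 GiB). A minimal sketch of checking whether a given input shape falls inside that range follows; `shape_in_range` is a hypothetical helper written for this note, not an MMDeploy or TensorRT function.

```python
from typing import Sequence

# Values copied from the config above.
MIN_SHAPE = [1, 3, 160, 160]
MAX_SHAPE = [1, 3, 608, 608]
MAX_WORKSPACE_SIZE = 1 << 30  # bytes


def shape_in_range(shape: Sequence[int],
                   min_shape: Sequence[int] = MIN_SHAPE,
                   max_shape: Sequence[int] = MAX_SHAPE) -> bool:
    """Return True if every dimension lies within [min, max] inclusive."""
    if len(shape) != len(min_shape):
        return False
    return all(lo <= dim <= hi
               for lo, dim, hi in zip(min_shape, shape, max_shape))


print(shape_in_range([1, 3, 416, 416]))  # inside the declared range
print(shape_in_range([1, 3, 800, 800]))  # larger than max_shape
print(MAX_WORKSPACE_SIZE == 1024 ** 3)
```
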

@@ -2,15 +2,16 @@
 <!-- TOC -->
-- [Tutorial : How to convert model](#how-to-convert-model)
-- [How to convert models from Pytorch to BACKEND](#how-to-convert-models-from-pytorch-to-other-backends)
-- [Prerequisite](#prerequisite)
-- [Usage](#usage)
-- [Description of all arguments](#description-of-all-arguments)
-- [How to evaluate the exported models](#how-to-evaluate-the-exported-models)
-- [List of supported models exportable to BACKEND](#list-of-supported-models-exportable-to-other-backends)
-- [Reminders](#reminders)
-- [FAQs](#faqs)
+- [How to convert model](#how-to-convert-model)
+- [How to convert models from Pytorch to other backends](#how-to-convert-models-from-pytorch-to-other-backends)
+- [Prerequisite](#prerequisite)
+- [Usage](#usage)
+- [Description of all arguments](#description-of-all-arguments)
+- [Example](#example)
+- [How to evaluate the exported models](#how-to-evaluate-the-exported-models)
+- [List of supported models exportable to other backends](#list-of-supported-models-exportable-to-other-backends)
+- [Reminders](#reminders)
+- [FAQs](#faqs)
 <!-- TOC -->
@@ -78,33 +79,33 @@ You can try to evaluate model, referring to [how_to_evaluate_a_model](./how_to_e
 The table below lists the models that are guaranteed to be exportable to other backend.
-| Model | codebase | model config file(example) | OnnxRuntime | TensorRT | NCNN | PPL | OpenVINO |
-| :----------------: | :--------------: | :---------------------------------------------------------------------------------------: | :---------: | :-----------: | :---:| :---: | :-------: |
-| RetinaNet | MMDetection | $PATH_TO_MMDET/configs/retinanet/retinanet_r50_fpn_1x_coco.py | Y | Y | Y | Y | Y |
-| Faster R-CNN | MMDetection | $PATH_TO_MMDET/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py | Y | Y | Y | Y | Y |
-| YOLOv3 | MMDetection | $PATH_TO_MMDET/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py | Y | Y | N | Y | N |
-| FCOS | MMDetection | $PATH_TO_MMDET/configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py | Y | Y | Y | N | Y |
-| FSAF | MMDetection | $PATH_TO_MMDET/configs/fsaf/fsaf_r50_fpn_1x_coco.py | Y | Y | Y | Y | Y |
-| Mask R-CNN | MMDetection | $PATH_TO_MMDET/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py | Y | Y | N | Y | Y |
-| SSD | MMDetection | $PATH_TO_MMDET/configs/ssd/ssd300_coco.py | Y | ? | ? | ? | Y |
-| FoveaBox | MMDetection | $PATH_TO_MMDET/configs/foveabox/fovea_r50_fpn_4x4_1x_coco.py | Y | ? | ? | ? | Y |
-| ATSS | MMDetection | $PATH_TO_MMDET/configs/atss/atss_r50_fpn_1x_coco.py | Y | ? | ? | ? | Y |
-| Cascade R-CNN | MMDetection | $PATH_TO_MMDET/configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py | Y | ? | ? | ? | Y |
-| Cascade Mask R-CNN | MMDetection | $PATH_TO_MMDET/configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py | Y | ? | ? | ? | Y |
-| ResNet | MMClassification | $PATH_TO_MMCLS/configs/resnet/resnet18_b32x8_imagenet.py | Y | Y | Y | Y | N |
-| ResNeXt | MMClassification | $PATH_TO_MMCLS/configs/resnext/resnext50_32x4d_b32x8_imagenet.py | Y | Y | Y | Y | N |
-| SE-ResNet | MMClassification | $PATH_TO_MMCLS/configs/seresnet/seresnet50_b32x8_imagenet.py | Y | Y | Y | Y | N |
-| MobileNetV2 | MMClassification | $PATH_TO_MMCLS/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py | Y | Y | Y | Y | N |
-| ShuffleNetV1 | MMClassification | $PATH_TO_MMCLS/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py | Y | Y | N | Y | N |
-| ShuffleNetV2 | MMClassification | $PATH_TO_MMCLS/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py | Y | Y | N | Y | N |
-| FCN | MMSegmentation | $PATH_TO_MMSEG/configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | N |
-| PSPNet | MMSegmentation | $PATH_TO_MMSEG/configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py | Y | Y | N | Y | N |
-| DeepLabV3 | MMSegmentation | $PATH_TO_MMSEG/configs/deeplabv3/deeplabv3_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | N |
-| DeepLabV3+ | MMSegmentation | $PATH_TO_MMSEG/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | N |
-| SRCNN | MMEditing | $PATH_TO_MMSEG/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py | Y | Y | N | Y | N |
-| ESRGAN | MMEditing | $PATH_TO_MMSEG/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py | Y | Y | N | Y | N |
-| DBNet | MMOCR | $PATH_TO_MMOCR/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py | Y | Y | Y | Y | N |
-| CRNN | MMOCR | $PATH_TO_MMOCR/configs/textrecog/tps/crnn_tps_academic_dataset.py | Y | Y | Y | N | N |
+| Model | codebase | OnnxRuntime | TensorRT | NCNN | PPL | OpenVINO | model config file(example) |
+|--------------------|------------------|:-----------:|:--------:|:----:|:---:|:--------:|:--------------------------------------------------------------------------------------|
+| RetinaNet | MMDetection | Y | Y | Y | Y | Y | $MMDET_DIR/configs/retinanet/retinanet_r50_fpn_1x_coco.py |
+| Faster R-CNN | MMDetection | Y | Y | Y | Y | Y | $MMDET_DIR/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py |
+| YOLOv3 | MMDetection | Y | Y | Y | Y | Y | $MMDET_DIR/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py |
+| FCOS | MMDetection | Y | Y | Y | N | Y | $MMDET_DIR/configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py |
+| FSAF | MMDetection | Y | Y | Y | Y | Y | $MMDET_DIR/configs/fsaf/fsaf_r50_fpn_1x_coco.py |
+| Mask R-CNN | MMDetection | Y | Y | N | Y | Y | $MMDET_DIR/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py |
+| SSD | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/ssd/ssd300_coco.py |
+| FoveaBox | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/foveabox/fovea_r50_fpn_4x4_1x_coco.py |
+| ATSS | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/atss/atss_r50_fpn_1x_coco.py |
+| Cascade R-CNN | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py |
+| Cascade Mask R-CNN | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py |
+| ResNet | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/resnet/resnet18_b32x8_imagenet.py |
+| ResNeXt | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/resnext/resnext50_32x4d_b32x8_imagenet.py |
+| SE-ResNet | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/seresnet/seresnet50_b32x8_imagenet.py |
+| MobileNetV2 | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py |
+| ShuffleNetV1 | MMClassification | Y | Y | N | Y | N | $MMCLS_DIR/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py |
+| ShuffleNetV2 | MMClassification | Y | Y | N | Y | N | $MMCLS_DIR/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py |
+| FCN | MMSegmentation | Y | Y | Y | Y | N | $MMSEG_DIR/configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py |
+| PSPNet | MMSegmentation | Y | Y | N | Y | N | $MMSEG_DIR/configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py |
+| DeepLabV3 | MMSegmentation | Y | Y | Y | Y | N | $MMSEG_DIR/configs/deeplabv3/deeplabv3_r50-d8_512x1024_40k_cityscapes.py |
+| DeepLabV3+ | MMSegmentation | Y | Y | Y | Y | N | $MMSEG_DIR/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py |
+| SRCNN | MMEditing | Y | Y | N | Y | N | $MMSEG_DIR/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py |
+| ESRGAN | MMEditing | Y | Y | N | Y | N | $MMSEG_DIR/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py |
+| DBNet | MMOCR | Y | Y | Y | Y | N | $MMOCR_DIR/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py |
+| CRNN | MMOCR | Y | Y | Y | N | N | $MMOCR_DIR/configs/textrecog/tps/crnn_tps_academic_dataset.py |
 ### Reminders

@@ -20,7 +20,8 @@ class ONNXNMSop(torch.autograd.Function):
         for batch_id in range(batch_size):
             for cls_id in range(num_class):
                 _boxes = boxes[batch_id, ...]
-                _scores = scores[batch_id, cls_id, ...]
+                # score_threshold=0 requires scores to be contiguous
+                _scores = scores[batch_id, cls_id, ...].contiguous()
                 _, box_inds = nms(
                     _boxes,
                     _scores,

@@ -3,9 +3,11 @@ from .atss_head import get_bboxes_of_atss_head
 from .fcos_head import get_bboxes_of_fcos_head
 from .fovea_head import get_bboxes_of_fovea_head
 from .rpn_head import get_bboxes_of_rpn_head
+from .yolo_head import yolov3_head__get_bboxes, yolov3_head__get_bboxes__ncnn
 
 __all__ = [
     'get_bboxes_of_anchor_head', 'get_bboxes_of_fcos_head',
     'get_bboxes_of_rpn_head', 'get_bboxes_of_fovea_head',
-    'get_bboxes_of_atss_head'
+    'get_bboxes_of_atss_head', 'yolov3_head__get_bboxes',
+    'yolov3_head__get_bboxes__ncnn'
 ]

@@ -102,10 +102,11 @@ def get_bboxes_of_fcos_head(ctx,
     max_output_boxes_per_class = post_params.max_output_boxes_per_class
     iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
     score_threshold = cfg.get('score_thr', post_params.score_threshold)
-    nms_pre = cfg.get('deploy_nms_pre', -1)
+    pre_top_k = post_params.pre_top_k
+    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
     return multiclass_nms(batch_mlvl_bboxes, batch_mlvl_scores,
                           max_output_boxes_per_class, iou_threshold,
-                          score_threshold, nms_pre, cfg.max_per_img)
+                          score_threshold, pre_top_k, keep_top_k)
 
 
 @FUNCTION_REWRITER.register_rewriter(
@@ -195,7 +196,8 @@ def get_bboxes_of_fcos_head_ncnn(ctx,
     max_output_boxes_per_class = post_params.max_output_boxes_per_class
     iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
     score_threshold = cfg.get('score_thr', post_params.score_threshold)
-    nms_pre = cfg.get('deploy_nms_pre', -1)
+    pre_top_k = post_params.pre_top_k
+    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
     return multiclass_nms(batch_mlvl_bboxes, batch_mlvl_scores,
                           max_output_boxes_per_class, iou_threshold,
-                          score_threshold, nms_pre, cfg.max_per_img)
+                          score_threshold, pre_top_k, keep_top_k)
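
Both FCOS hunks replace the hard requirement `cfg.max_per_img` with a lookup that falls back to the deployment config when the model's test config has no such key. The pattern can be sketched with plain dictionaries as below. This is an assumption-laden illustration: `resolve_keep_top_k` is a hypothetical name, and the real code operates on config objects rather than `dict`.

```python
def resolve_keep_top_k(test_cfg: dict, post_params: dict) -> int:
    """Prefer the model's own 'max_per_img'; otherwise use the deploy default."""
    return test_cfg.get('max_per_img', post_params['keep_top_k'])


post_params = {'keep_top_k': 100, 'pre_top_k': -1}

# The model config sets its own limit, so that value is used.
print(resolve_keep_top_k({'max_per_img': 50}, post_params))
# The key is absent, so the deploy config's keep_top_k is used instead.
print(resolve_keep_top_k({}, post_params))
```
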

@@ -0,0 +1,309 @@
import torch

from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.mmdet.core import multiclass_nms
from mmdeploy.mmdet.export import pad_with_value
from mmdeploy.utils import (Backend, get_backend, get_mmdet_params,
                            is_dynamic_shape)


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdet.models.YOLOV3Head.get_bboxes')
def yolov3_head__get_bboxes(ctx,
                            self,
                            pred_maps,
                            with_nms=True,
                            cfg=None,
                            **kwargs):
    """Rewrite `get_bboxes` for default backend.

    Transform network output for a batch into bbox predictions.

    Args:
        ctx: Context that contains original meta information.
        self: Represent the instance of the original class.
        pred_maps (list[Tensor]): Raw predictions for a batch of images.
        cfg (mmcv.Config | None): Test / postprocessing configuration,
            if None, test_cfg would be used. Default: None.
        with_nms (bool): If True, do nms before return boxes.
            Default: True.

    Returns:
        tuple[Tensor, Tensor]: The first item is an (N, num_box, 5) tensor,
            where 5 represent (tl_x, tl_y, br_x, br_y, score), N is batch
            size and the score between 0 and 1. The shape of the second
            tensor in the tuple is (N, num_box), and each element
            represents the class label of the corresponding box.
    """
    is_dynamic_flag = is_dynamic_shape(ctx.cfg)
    num_levels = len(pred_maps)
    pred_maps_list = [pred_maps[i].detach() for i in range(num_levels)]

    cfg = self.test_cfg if cfg is None else cfg
    assert len(pred_maps_list) == self.num_levels

    device = pred_maps_list[0].device
    batch_size = pred_maps_list[0].shape[0]

    featmap_sizes = [
        pred_maps_list[i].shape[-2:] for i in range(self.num_levels)
    ]
    multi_lvl_anchors = self.anchor_generator.grid_anchors(
        featmap_sizes, device)
    pre_topk = cfg.get('nms_pre', -1)
    multi_lvl_bboxes = []
    multi_lvl_cls_scores = []
    multi_lvl_conf_scores = []
    for i in range(self.num_levels):
        # get some key info for current scale
        pred_map = pred_maps_list[i]
        stride = self.featmap_strides[i]
        # (b,h, w, num_anchors*num_attrib) ->
        # (b,h*w*num_anchors, num_attrib)
        pred_map = pred_map.permute(0, 2, 3,
                                    1).reshape(batch_size, -1, self.num_attrib)
        # Inplace operation like
        # ```pred_map[..., :2] = \torch.sigmoid(pred_map[..., :2])```
        # would create constant tensor when exporting to onnx
        pred_map_conf = torch.sigmoid(pred_map[..., :2])
        pred_map_rest = pred_map[..., 2:]
        pred_map = torch.cat([pred_map_conf, pred_map_rest], dim=-1)
        pred_map_boxes = pred_map[..., :4]
        multi_lvl_anchor = multi_lvl_anchors[i]
        # use static anchor if input shape is static
        if not is_dynamic_flag:
            multi_lvl_anchor = multi_lvl_anchor.data
        multi_lvl_anchor = multi_lvl_anchor.unsqueeze(0).expand_as(
            pred_map_boxes)
        bbox_pred = self.bbox_coder.decode(multi_lvl_anchor, pred_map_boxes,
                                           stride)
        # conf and cls
        conf_pred = torch.sigmoid(pred_map[..., 4])
        cls_pred = torch.sigmoid(pred_map[..., 5:]).view(
            batch_size, -1, self.num_classes)  # Cls pred one-hot.
        backend = get_backend(ctx.cfg)
        # topk in tensorrt does not support shape<k
        # concate zero to enable topk,
        if backend == Backend.TENSORRT:
            bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
            conf_pred = pad_with_value(conf_pred, 1, pre_topk, 0.)
            cls_pred = pad_with_value(cls_pred, 1, pre_topk, 0.)

        if pre_topk > 0:
            _, topk_inds = conf_pred.topk(pre_topk)
            batch_inds = torch.arange(
                batch_size, device=device).view(-1,
                                                1).expand_as(topk_inds).long()
            # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134  # noqa: E501
            transformed_inds = (bbox_pred.shape[1] * batch_inds + topk_inds)
            bbox_pred = bbox_pred.reshape(-1, 4)[transformed_inds, :].reshape(
                batch_size, -1, 4)
            cls_pred = cls_pred.reshape(
                -1, self.num_classes)[transformed_inds, :].reshape(
                    batch_size, -1, self.num_classes)
            conf_pred = conf_pred.reshape(-1, 1)[transformed_inds].reshape(
                batch_size, -1)

        # Save the result of current scale
        multi_lvl_bboxes.append(bbox_pred)
        multi_lvl_cls_scores.append(cls_pred)
        multi_lvl_conf_scores.append(conf_pred)

    # Merge the results of different scales together
    batch_mlvl_bboxes = torch.cat(multi_lvl_bboxes, dim=1)
    batch_mlvl_scores = torch.cat(multi_lvl_cls_scores, dim=1)
    batch_mlvl_conf_scores = torch.cat(multi_lvl_conf_scores, dim=1)

    post_params = get_mmdet_params(ctx.cfg)
    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    confidence_threshold = cfg.get('conf_thr',
                                   post_params.confidence_threshold)

    # follow original pipeline of YOLOv3
    if confidence_threshold > 0:
        mask = (batch_mlvl_conf_scores >= confidence_threshold).float()
        batch_mlvl_conf_scores *= mask
    if score_threshold > 0:
        mask = (batch_mlvl_scores > score_threshold).float()
        batch_mlvl_scores *= mask

    batch_mlvl_conf_scores = batch_mlvl_conf_scores.unsqueeze(2).expand_as(
        batch_mlvl_scores)
    batch_mlvl_scores = batch_mlvl_scores * batch_mlvl_conf_scores

    if with_nms:
        max_output_boxes_per_class = post_params.max_output_boxes_per_class
        iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
        pre_top_k = post_params.pre_top_k
        keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
        # keep aligned with original pipeline, improve
        # mAP by 1% for YOLOv3 in ONNX
        score_threshold = 0
        return multiclass_nms(
            batch_mlvl_bboxes,
            batch_mlvl_scores,
            max_output_boxes_per_class,
            iou_threshold=iou_threshold,
            score_threshold=score_threshold,
            pre_top_k=pre_top_k,
            keep_top_k=keep_top_k)
    else:
        return batch_mlvl_bboxes, batch_mlvl_scores


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdet.models.YOLOV3Head.get_bboxes', backend='ncnn')
def yolov3_head__get_bboxes__ncnn(ctx,
                                  self,
                                  pred_maps,
                                  with_nms=True,
                                  cfg=None,
                                  **kwargs):
    """Rewrite `get_bboxes` for ncnn backend.

    Transform network output for a batch into bbox predictions.

    Args:
        ctx: Context that contains original meta information.
        self: Represent the instance of the original class.
        pred_maps (list[Tensor]): Raw predictions for a batch of images.
        cfg (mmcv.Config | None): Test / postprocessing configuration,
            if None, test_cfg would be used. Default: None.
        with_nms (bool): If True, do nms before return boxes.
            Default: True.

    Returns:
        tuple[Tensor, Tensor]: The first item is an (N, num_box, 5) tensor,
            where 5 represent (tl_x, tl_y, br_x, br_y, score), N is batch
            size and the score between 0 and 1. The shape of the second
            tensor in the tuple is (N, num_box), and each element
            represents the class label of the corresponding box.
    """
    num_levels = len(pred_maps)
    pred_maps_list = [pred_maps[i].detach() for i in range(num_levels)]

    cfg = self.test_cfg if cfg is None else cfg
    assert len(pred_maps_list) == self.num_levels

    device = pred_maps_list[0].device
    batch_size = pred_maps_list[0].shape[0]

    featmap_sizes = [
        pred_maps_list[i].shape[-2:] for i in range(self.num_levels)
    ]
    multi_lvl_anchors = self.anchor_generator.grid_anchors(
        featmap_sizes, device)
    pre_topk = cfg.get('nms_pre', -1)
    multi_lvl_bboxes = []
    multi_lvl_cls_scores = []
    multi_lvl_conf_scores = []
    for i in range(self.num_levels):
        # get some key info for current scale
        pred_map = pred_maps_list[i]
        stride = self.featmap_strides[i]
        # (b,h, w, num_anchors*num_attrib) ->
        # (b,h*w*num_anchors, num_attrib)
        pred_map = pred_map.permute(0, 2, 3,
                                    1).reshape(batch_size, -1, self.num_attrib)
        # Inplace operation like
        # ```pred_map[..., :2] = \torch.sigmoid(pred_map[..., :2])```
        # would create constant tensor when exporting to onnx
        pred_map_conf = torch.sigmoid(pred_map[..., :2])
        pred_map_rest = pred_map[..., 2:]
        # dim must be written as 2, but not -1, because ncnn implicit batch
        # mechanism.
        pred_map = torch.cat([pred_map_conf, pred_map_rest], dim=2)
        pred_map_boxes = pred_map[..., :4]
        multi_lvl_anchor = multi_lvl_anchors[i]
        # use static anchor if input shape is static
        multi_lvl_anchor = multi_lvl_anchor.unsqueeze(0).expand_as(
            pred_map_boxes).data
        bbox_pred = self.bbox_coder.decode(multi_lvl_anchor, pred_map_boxes,
                                           stride)
        # conf and cls
        conf_pred = torch.sigmoid(pred_map[..., 4])
        cls_pred = torch.sigmoid(pred_map[..., 5:]).view(
            batch_size, -1, self.num_classes)  # Cls pred one-hot.

        if pre_topk > 0:
            _, topk_inds = conf_pred.topk(pre_topk)
            topk_inds = topk_inds.view(-1)
            bbox_pred = bbox_pred[:, topk_inds, :]
            cls_pred = cls_pred[:, topk_inds, :]
            conf_pred = conf_pred[:, topk_inds]

        # Save the result of current scale
        multi_lvl_bboxes.append(bbox_pred)
        multi_lvl_cls_scores.append(cls_pred)
        multi_lvl_conf_scores.append(conf_pred)

    # Merge the results of different scales together
    batch_mlvl_bboxes = torch.cat(multi_lvl_bboxes, dim=1)
    batch_mlvl_scores = torch.cat(multi_lvl_cls_scores, dim=1)
    batch_mlvl_conf_scores = torch.cat(multi_lvl_conf_scores, dim=1)

    post_params = get_mmdet_params(ctx.cfg)
    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    confidence_threshold = cfg.get('conf_thr',
                                   post_params.confidence_threshold)

    # helper function for creating Threshold op
    def _create_threshold(x, thresh):

        class ThresholdOp(torch.autograd.Function):
            """Create Threshold op."""

            @staticmethod
            def forward(ctx, x, threshold):
                return x > threshold

            @staticmethod
            def symbolic(g, x, threshold):
                return g.op(
                    'mmdeploy::Threshold', x, threshold_f=threshold, outputs=1)

        return ThresholdOp.apply(x, thresh)

    # follow original pipeline of YOLOv3
    if confidence_threshold > 0:
        mask = _create_threshold(batch_mlvl_conf_scores,
                                 confidence_threshold).float()
        batch_mlvl_conf_scores *= mask
    if score_threshold > 0:
        mask = _create_threshold(batch_mlvl_scores, score_threshold).float()
        batch_mlvl_scores *= mask

    # NCNN broadcast needs the same in channel dimension.
    _batch_mlvl_conf_scores = batch_mlvl_conf_scores.unsqueeze(2).unsqueeze(3)
    _batch_mlvl_scores = batch_mlvl_scores.unsqueeze(3)
    batch_mlvl_scores = (_batch_mlvl_scores * _batch_mlvl_conf_scores).reshape(
        batch_mlvl_scores.shape)
    # Although batch_mlvl_bboxes already has the shape of
    # (batch_size, -1, 4), ncnn implicit batch mechanism in the model and
    # ncnn channel alignment would result in a shape of
    # (batch_size, -1, 4, 1). So, we need a reshape op to ensure the
    # batch_mlvl_bboxes shape is right.
    batch_mlvl_bboxes = batch_mlvl_bboxes.reshape(batch_size, -1, 4)

    if with_nms:
        max_output_boxes_per_class = post_params.max_output_boxes_per_class
        iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
        pre_top_k = post_params.pre_top_k
        keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
        # keep aligned with original pipeline, improve
        # mAP by 1% for YOLOv3 in ONNX
        score_threshold = 0
        return multiclass_nms(
            batch_mlvl_bboxes,
            batch_mlvl_scores,
            max_output_boxes_per_class,
            iou_threshold=iou_threshold,
            score_threshold=score_threshold,
            pre_top_k=pre_top_k,
            keep_top_k=keep_top_k)
    else:
        return batch_mlvl_bboxes, batch_mlvl_scores
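
Both rewriters build the final per-class score the same way: zero out boxes whose objectness is below `confidence_threshold`, zero out class scores not above `score_threshold`, then multiply the two. The sketch below reproduces that arithmetic on plain Python lists for a single image, assuming the comparison operators of the default rewriter (`>=` for objectness, `>` for class scores; the ncnn variant uses `>` for both through its `Threshold` op). `yolov3_scores` is a hypothetical name and this is not the tensor code from the commit.

```python
def yolov3_scores(cls_scores, conf_scores, score_thr=0.05, conf_thr=0.005):
    """Combine class scores and objectness for one image.

    cls_scores: one list of per-class scores per box.
    conf_scores: one objectness value per box.
    """
    combined = []
    for box_cls, conf in zip(cls_scores, conf_scores):
        # Objectness mask: applied only when the threshold is positive.
        if conf_thr > 0 and not conf >= conf_thr:
            conf = 0.0
        row = []
        for score in box_cls:
            # Class-score mask: strict comparison, as in the default rewriter.
            if score_thr > 0 and not score > score_thr:
                score = 0.0
            row.append(score * conf)
        combined.append(row)
    return combined


cls = [[0.9, 0.01], [0.8, 0.6]]
conf = [0.5, 0.001]
# Box 0 keeps only its first class; box 1 is dropped by the objectness mask.
print(yolov3_scores(cls, conf))
```

Because the masked scores are already zero, the rewriters can then pass `score_threshold = 0` to `multiclass_nms`, as the code above does.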

@@ -786,3 +786,97 @@ def test_cascade_roi_head_with_mask(backend_type):
     assert np.allclose(
         expected_segm_results, segm_results, rtol=1e-03,
         atol=1e-05), 'segm_results do not match.'
+
+
+def get_yolov3_head_model():
+    """yolov3 Head Config."""
+    test_cfg = mmcv.Config(
+        dict(
+            nms_pre=1000,
+            min_bbox_size=0,
+            score_thr=0.05,
+            conf_thr=0.005,
+            nms=dict(type='nms', iou_threshold=0.45),
+            max_per_img=100))
+
+    from mmdet.models import YOLOV3Head
+    model = YOLOV3Head(
+        num_classes=4,
+        in_channels=[16, 8, 4],
+        out_channels=[32, 16, 8],
+        test_cfg=test_cfg)
+
+    model.requires_grad_(False)
+    return model
+
+
+@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn', 'openvino'])
+def test_yolov3_head_get_bboxes(backend_type):
+    """Test get_bboxes rewrite of yolov3 head."""
+    pytest.importorskip(backend_type, reason=f'requires {backend_type}')
+    yolov3_head = get_yolov3_head_model()
+    yolov3_head.cpu().eval()
+    s = 128
+    img_metas = [{
+        'scale_factor': np.ones(4),
+        'pad_shape': (s, s, 3),
+        'img_shape': (s, s, 3)
+    }]
+
+    output_names = ['dets', 'labels']
+    deploy_cfg = mmcv.Config(
+        dict(
+            backend_config=dict(type=backend_type),
+            onnx_config=dict(output_names=output_names, input_shape=None),
+            codebase_config=dict(
+                type='mmdet',
+                task='ObjectDetection',
+                post_processing=dict(
+                    score_threshold=0.05,
+                    iou_threshold=0.45,
+                    confidence_threshold=0.005,
+                    max_output_boxes_per_class=200,
+                    pre_top_k=-1,
+                    keep_top_k=100,
+                    background_label_id=-1,
+                ))))
+
+    seed_everything(1234)
+    pred_maps = [
+        torch.rand(1, 27, 5, 5),
+        torch.rand(1, 27, 10, 10),
+        torch.rand(1, 27, 20, 20)
+    ]
+
+    # to get outputs of pytorch model
+    model_inputs = {'pred_maps': pred_maps, 'img_metas': img_metas}
+    model_outputs = get_model_outputs(yolov3_head, 'get_bboxes', model_inputs)
+
+    # to get outputs of onnx model after rewrite
+    wrapped_model = WrapModel(
+        yolov3_head, 'get_bboxes', img_metas=img_metas[0], with_nms=True)
+    rewrite_inputs = {
+        'pred_maps': pred_maps,
+    }
+    rewrite_outputs, is_backend_output = get_rewrite_outputs(
+        wrapped_model=wrapped_model,
+        model_inputs=rewrite_inputs,
+        deploy_cfg=deploy_cfg)
+
+    if is_backend_output:
+        if isinstance(rewrite_outputs, dict):
+            rewrite_outputs = [
+                value for name, value in rewrite_outputs.items()
+                if name in output_names
+            ]
+        for model_output, rewrite_output in zip(model_outputs[0],
+                                                rewrite_outputs):
+            model_output = model_output.squeeze().cpu().numpy()
+            rewrite_output = rewrite_output.squeeze()
+            # hard code to make two tensors with the same shape
+            # rewrite and original codes applied different nms strategy
+            assert np.allclose(
+                model_output[:rewrite_output.shape[0]],
+                rewrite_output,
+                rtol=1e-03,
+                atol=1e-05)
+    else:
+        assert rewrite_outputs is not None
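
The test compares only the leading rows of the PyTorch output against the backend output, because the two pipelines apply NMS differently and may return a different number of boxes. A minimal sketch of that prefix comparison on flat lists of floats is given below. It assumes the element-wise criterion documented for `numpy.allclose(a, b)`, namely `|a - b| <= atol + rtol * |b|`; `prefix_allclose` is a hypothetical helper, and returning `False` when the reference is shorter is a choice made for this sketch.

```python
def prefix_allclose(reference, candidate, rtol=1e-3, atol=1e-5):
    """Check candidate against the first len(candidate) entries of reference."""
    head = reference[:len(candidate)]
    if len(head) != len(candidate):
        return False  # reference has fewer entries than candidate
    return all(abs(a - b) <= atol + rtol * abs(b)
               for a, b in zip(head, candidate))


# The reference has two extra trailing entries that are ignored.
print(prefix_allclose([1.0, 2.0, 0.0, 0.0], [1.0005, 2.0]))
# A difference of 0.1 is far outside the tolerance.
print(prefix_allclose([1.0, 2.0], [1.1, 2.0]))
```
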