mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Feature]: Support yolov3 (#167)
* support yolov3 with ort and trt * add ncnn compare * fix yolo_head ncnn rewriter * align performance with ort trt for yolov3 * update doc * add test for compare with equal, less, greater * change namespace * reformat cpp * fix lint * fix lint * add unit test for yolov3 head * remove compare op * update doc * update table format in docs * update comments * update Co-authored-by: hanrui1sensetime <hanrui1@sensetime.com>
This commit is contained in:
parent
2ffc657665
commit
cd51f12f32
@ -3403,8 +3403,8 @@ int main(int argc, char** argv) {
|
|||||||
fprintf(pp, "%-16s", "LayerNorm");
|
fprintf(pp, "%-16s", "LayerNorm");
|
||||||
} else if (op == "LeakyRelu") {
|
} else if (op == "LeakyRelu") {
|
||||||
fprintf(pp, "%-16s", "ReLU");
|
fprintf(pp, "%-16s", "ReLU");
|
||||||
} else if (op == "Less") {
|
} else if (op == "Threshold") {
|
||||||
fprintf(pp, "%-16s", "Compare");
|
fprintf(pp, "%-16s", "Threshold");
|
||||||
} else if (op == "Log") {
|
} else if (op == "Log") {
|
||||||
fprintf(pp, "%-16s", "UnaryOp");
|
fprintf(pp, "%-16s", "UnaryOp");
|
||||||
} else if (op == "LRN") {
|
} else if (op == "LRN") {
|
||||||
@ -4315,11 +4315,10 @@ int main(int argc, char** argv) {
|
|||||||
}
|
}
|
||||||
} else if (op == "LeakyRelu") {
|
} else if (op == "LeakyRelu") {
|
||||||
float alpha = get_node_attr_f(node, "alpha", 0.01f);
|
float alpha = get_node_attr_f(node, "alpha", 0.01f);
|
||||||
|
|
||||||
fprintf(pp, " 0=%e", alpha);
|
fprintf(pp, " 0=%e", alpha);
|
||||||
} else if (op == "Less") {
|
} else if (op == "Threshold") {
|
||||||
int op_type = 1;
|
float threshold = get_node_attr_f(node, "threshold", 0.f);
|
||||||
fprintf(pp, " 0=%d", op_type);
|
fprintf(pp, " 0=%e", threshold);
|
||||||
} else if (op == "Log") {
|
} else if (op == "Log") {
|
||||||
int op_type = 8;
|
int op_type = 8;
|
||||||
fprintf(pp, " 0=%d", op_type);
|
fprintf(pp, " 0=%d", op_type);
|
||||||
|
@ -6,6 +6,7 @@ codebase_config = dict(
|
|||||||
task='ObjectDetection',
|
task='ObjectDetection',
|
||||||
post_processing=dict(
|
post_processing=dict(
|
||||||
score_threshold=0.05,
|
score_threshold=0.05,
|
||||||
|
confidence_threshold=0.005, # for YOLOv3
|
||||||
iou_threshold=0.5,
|
iou_threshold=0.5,
|
||||||
max_output_boxes_per_class=200,
|
max_output_boxes_per_class=200,
|
||||||
pre_top_k=-1,
|
pre_top_k=-1,
|
||||||
|
@ -0,0 +1,12 @@
|
|||||||
|
_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/tensorrt.py']
|
||||||
|
|
||||||
|
backend_config = dict(
|
||||||
|
common_config=dict(max_workspace_size=1 << 30),
|
||||||
|
model_inputs=[
|
||||||
|
dict(
|
||||||
|
input_shapes=dict(
|
||||||
|
input=dict(
|
||||||
|
min_shape=[1, 3, 160, 160],
|
||||||
|
opt_shape=[1, 3, 608, 608],
|
||||||
|
max_shape=[1, 3, 608, 608])))
|
||||||
|
])
|
@ -2,15 +2,16 @@
|
|||||||
|
|
||||||
<!-- TOC -->
|
<!-- TOC -->
|
||||||
|
|
||||||
- [Tutorial : How to convert model](#how-to-convert-model)
|
- [How to convert model](#how-to-convert-model)
|
||||||
- [How to convert models from Pytorch to BACKEND](#how-to-convert-models-from-pytorch-to-other-backends)
|
- [How to convert models from Pytorch to other backends](#how-to-convert-models-from-pytorch-to-other-backends)
|
||||||
- [Prerequisite](#prerequisite)
|
- [Prerequisite](#prerequisite)
|
||||||
- [Usage](#usage)
|
- [Usage](#usage)
|
||||||
- [Description of all arguments](#description-of-all-arguments)
|
- [Description of all arguments](#description-of-all-arguments)
|
||||||
- [How to evaluate the exported models](#how-to-evaluate-the-exported-models)
|
- [Example](#example)
|
||||||
- [List of supported models exportable to BACKEND](#list-of-supported-models-exportable-to-other-backends)
|
- [How to evaluate the exported models](#how-to-evaluate-the-exported-models)
|
||||||
- [Reminders](#reminders)
|
- [List of supported models exportable to other backends](#list-of-supported-models-exportable-to-other-backends)
|
||||||
- [FAQs](#faqs)
|
- [Reminders](#reminders)
|
||||||
|
- [FAQs](#faqs)
|
||||||
|
|
||||||
<!-- TOC -->
|
<!-- TOC -->
|
||||||
|
|
||||||
@ -78,33 +79,33 @@ You can try to evaluate model, referring to [how_to_evaluate_a_model](./how_to_e
|
|||||||
|
|
||||||
The table below lists the models that are guaranteed to be exportable to other backends.
|
The table below lists the models that are guaranteed to be exportable to other backends.
|
||||||
|
|
||||||
| Model | codebase | model config file(example) | OnnxRuntime | TensorRT | NCNN | PPL | OpenVINO |
|
| Model | codebase | OnnxRuntime | TensorRT | NCNN | PPL | OpenVINO | model config file(example) |
|
||||||
| :----------------: | :--------------: | :---------------------------------------------------------------------------------------: | :---------: | :-----------: | :---:| :---: | :-------: |
|
|--------------------|------------------|:-----------:|:--------:|:----:|:---:|:--------:|:--------------------------------------------------------------------------------------|
|
||||||
| RetinaNet | MMDetection | $PATH_TO_MMDET/configs/retinanet/retinanet_r50_fpn_1x_coco.py | Y | Y | Y | Y | Y |
|
| RetinaNet | MMDetection | Y | Y | Y | Y | Y | $MMDET_DIR/configs/retinanet/retinanet_r50_fpn_1x_coco.py |
|
||||||
| Faster R-CNN | MMDetection | $PATH_TO_MMDET/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py | Y | Y | Y | Y | Y |
|
| Faster R-CNN | MMDetection | Y | Y | Y | Y | Y | $MMDET_DIR/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py |
|
||||||
| YOLOv3 | MMDetection | $PATH_TO_MMDET/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py | Y | Y | N | Y | N |
|
| YOLOv3 | MMDetection | Y | Y | Y | Y | Y | $MMDET_DIR/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py |
|
||||||
| FCOS | MMDetection | $PATH_TO_MMDET/configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py | Y | Y | Y | N | Y |
|
| FCOS | MMDetection | Y | Y | Y | N | Y | $MMDET_DIR/configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py |
|
||||||
| FSAF | MMDetection | $PATH_TO_MMDET/configs/fsaf/fsaf_r50_fpn_1x_coco.py | Y | Y | Y | Y | Y |
|
| FSAF | MMDetection | Y | Y | Y | Y | Y | $MMDET_DIR/configs/fsaf/fsaf_r50_fpn_1x_coco.py |
|
||||||
| Mask R-CNN | MMDetection | $PATH_TO_MMDET/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py | Y | Y | N | Y | Y |
|
| Mask R-CNN | MMDetection | Y | Y | N | Y | Y | $MMDET_DIR/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py |
|
||||||
| SSD | MMDetection | $PATH_TO_MMDET/configs/ssd/ssd300_coco.py | Y | ? | ? | ? | Y |
|
| SSD | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/ssd/ssd300_coco.py |
|
||||||
| FoveaBox | MMDetection | $PATH_TO_MMDET/configs/foveabox/fovea_r50_fpn_4x4_1x_coco.py | Y | ? | ? | ? | Y |
|
| FoveaBox | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/foveabox/fovea_r50_fpn_4x4_1x_coco.py |
|
||||||
| ATSS | MMDetection | $PATH_TO_MMDET/configs/atss/atss_r50_fpn_1x_coco.py | Y | ? | ? | ? | Y |
|
| ATSS | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/atss/atss_r50_fpn_1x_coco.py |
|
||||||
| Cascade R-CNN | MMDetection | $PATH_TO_MMDET/configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py | Y | ? | ? | ? | Y |
|
| Cascade R-CNN | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py |
|
||||||
| Cascade Mask R-CNN | MMDetection | $PATH_TO_MMDET/configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py | Y | ? | ? | ? | Y |
|
| Cascade Mask R-CNN | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py |
|
||||||
| ResNet | MMClassification | $PATH_TO_MMCLS/configs/resnet/resnet18_b32x8_imagenet.py | Y | Y | Y | Y | N |
|
| ResNet | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/resnet/resnet18_b32x8_imagenet.py |
|
||||||
| ResNeXt | MMClassification | $PATH_TO_MMCLS/configs/resnext/resnext50_32x4d_b32x8_imagenet.py | Y | Y | Y | Y | N |
|
| ResNeXt | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/resnext/resnext50_32x4d_b32x8_imagenet.py |
|
||||||
| SE-ResNet | MMClassification | $PATH_TO_MMCLS/configs/seresnet/seresnet50_b32x8_imagenet.py | Y | Y | Y | Y | N |
|
| SE-ResNet | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/seresnet/seresnet50_b32x8_imagenet.py |
|
||||||
| MobileNetV2 | MMClassification | $PATH_TO_MMCLS/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py | Y | Y | Y | Y | N |
|
| MobileNetV2 | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py |
|
||||||
| ShuffleNetV1 | MMClassification | $PATH_TO_MMCLS/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py | Y | Y | N | Y | N |
|
| ShuffleNetV1 | MMClassification | Y | Y | N | Y | N | $MMCLS_DIR/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py |
|
||||||
| ShuffleNetV2 | MMClassification | $PATH_TO_MMCLS/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py | Y | Y | N | Y | N |
|
| ShuffleNetV2 | MMClassification | Y | Y | N | Y | N | $MMCLS_DIR/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py |
|
||||||
| FCN | MMSegmentation | $PATH_TO_MMSEG/configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | N |
|
| FCN | MMSegmentation | Y | Y | Y | Y | N | $MMSEG_DIR/configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py |
|
||||||
| PSPNet | MMSegmentation | $PATH_TO_MMSEG/configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py | Y | Y | N | Y | N |
|
| PSPNet | MMSegmentation | Y | Y | N | Y | N | $MMSEG_DIR/configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py |
|
||||||
| DeepLabV3 | MMSegmentation | $PATH_TO_MMSEG/configs/deeplabv3/deeplabv3_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | N |
|
| DeepLabV3 | MMSegmentation | Y | Y | Y | Y | N | $MMSEG_DIR/configs/deeplabv3/deeplabv3_r50-d8_512x1024_40k_cityscapes.py |
|
||||||
| DeepLabV3+ | MMSegmentation | $PATH_TO_MMSEG/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | N |
|
| DeepLabV3+ | MMSegmentation | Y | Y | Y | Y | N | $MMSEG_DIR/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py |
|
||||||
| SRCNN | MMEditing | $PATH_TO_MMSEG/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py | Y | Y | N | Y | N |
|
| SRCNN | MMEditing | Y | Y | N | Y | N | $MMEDIT_DIR/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py |
|
||||||
| ESRGAN | MMEditing | $PATH_TO_MMSEG/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py | Y | Y | N | Y | N |
|
| ESRGAN | MMEditing | Y | Y | N | Y | N | $MMEDIT_DIR/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py |
|
||||||
| DBNet | MMOCR | $PATH_TO_MMOCR/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py | Y | Y | Y | Y | N |
|
| DBNet | MMOCR | Y | Y | Y | Y | N | $MMOCR_DIR/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py |
|
||||||
| CRNN | MMOCR | $PATH_TO_MMOCR/configs/textrecog/tps/crnn_tps_academic_dataset.py | Y | Y | Y | N | N |
|
| CRNN | MMOCR | Y | Y | Y | N | N | $MMOCR_DIR/configs/textrecog/tps/crnn_tps_academic_dataset.py |
|
||||||
|
|
||||||
### Reminders
|
### Reminders
|
||||||
|
|
||||||
|
@ -20,7 +20,8 @@ class ONNXNMSop(torch.autograd.Function):
|
|||||||
for batch_id in range(batch_size):
|
for batch_id in range(batch_size):
|
||||||
for cls_id in range(num_class):
|
for cls_id in range(num_class):
|
||||||
_boxes = boxes[batch_id, ...]
|
_boxes = boxes[batch_id, ...]
|
||||||
_scores = scores[batch_id, cls_id, ...]
|
# score_threshold=0 requires scores to be contiguous
|
||||||
|
_scores = scores[batch_id, cls_id, ...].contiguous()
|
||||||
_, box_inds = nms(
|
_, box_inds = nms(
|
||||||
_boxes,
|
_boxes,
|
||||||
_scores,
|
_scores,
|
||||||
|
@ -3,9 +3,11 @@ from .atss_head import get_bboxes_of_atss_head
|
|||||||
from .fcos_head import get_bboxes_of_fcos_head
|
from .fcos_head import get_bboxes_of_fcos_head
|
||||||
from .fovea_head import get_bboxes_of_fovea_head
|
from .fovea_head import get_bboxes_of_fovea_head
|
||||||
from .rpn_head import get_bboxes_of_rpn_head
|
from .rpn_head import get_bboxes_of_rpn_head
|
||||||
|
from .yolo_head import yolov3_head__get_bboxes, yolov3_head__get_bboxes__ncnn
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'get_bboxes_of_anchor_head', 'get_bboxes_of_fcos_head',
|
'get_bboxes_of_anchor_head', 'get_bboxes_of_fcos_head',
|
||||||
'get_bboxes_of_rpn_head', 'get_bboxes_of_fovea_head',
|
'get_bboxes_of_rpn_head', 'get_bboxes_of_fovea_head',
|
||||||
'get_bboxes_of_atss_head'
|
'get_bboxes_of_atss_head', 'yolov3_head__get_bboxes',
|
||||||
|
'yolov3_head__get_bboxes__ncnn'
|
||||||
]
|
]
|
||||||
|
@ -102,10 +102,11 @@ def get_bboxes_of_fcos_head(ctx,
|
|||||||
max_output_boxes_per_class = post_params.max_output_boxes_per_class
|
max_output_boxes_per_class = post_params.max_output_boxes_per_class
|
||||||
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
|
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
|
||||||
score_threshold = cfg.get('score_thr', post_params.score_threshold)
|
score_threshold = cfg.get('score_thr', post_params.score_threshold)
|
||||||
nms_pre = cfg.get('deploy_nms_pre', -1)
|
pre_top_k = post_params.pre_top_k
|
||||||
|
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
|
||||||
return multiclass_nms(batch_mlvl_bboxes, batch_mlvl_scores,
|
return multiclass_nms(batch_mlvl_bboxes, batch_mlvl_scores,
|
||||||
max_output_boxes_per_class, iou_threshold,
|
max_output_boxes_per_class, iou_threshold,
|
||||||
score_threshold, nms_pre, cfg.max_per_img)
|
score_threshold, pre_top_k, keep_top_k)
|
||||||
|
|
||||||
|
|
||||||
@FUNCTION_REWRITER.register_rewriter(
|
@FUNCTION_REWRITER.register_rewriter(
|
||||||
@ -195,7 +196,8 @@ def get_bboxes_of_fcos_head_ncnn(ctx,
|
|||||||
max_output_boxes_per_class = post_params.max_output_boxes_per_class
|
max_output_boxes_per_class = post_params.max_output_boxes_per_class
|
||||||
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
|
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
|
||||||
score_threshold = cfg.get('score_thr', post_params.score_threshold)
|
score_threshold = cfg.get('score_thr', post_params.score_threshold)
|
||||||
nms_pre = cfg.get('deploy_nms_pre', -1)
|
pre_top_k = post_params.pre_top_k
|
||||||
|
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
|
||||||
return multiclass_nms(batch_mlvl_bboxes, batch_mlvl_scores,
|
return multiclass_nms(batch_mlvl_bboxes, batch_mlvl_scores,
|
||||||
max_output_boxes_per_class, iou_threshold,
|
max_output_boxes_per_class, iou_threshold,
|
||||||
score_threshold, nms_pre, cfg.max_per_img)
|
score_threshold, pre_top_k, keep_top_k)
|
||||||
|
309
mmdeploy/mmdet/models/dense_heads/yolo_head.py
Normal file
309
mmdeploy/mmdet/models/dense_heads/yolo_head.py
Normal file
@ -0,0 +1,309 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from mmdeploy.core import FUNCTION_REWRITER
|
||||||
|
from mmdeploy.mmdet.core import multiclass_nms
|
||||||
|
from mmdeploy.mmdet.export import pad_with_value
|
||||||
|
from mmdeploy.utils import (Backend, get_backend, get_mmdet_params,
|
||||||
|
is_dynamic_shape)
|
||||||
|
|
||||||
|
|
||||||
|
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdet.models.YOLOV3Head.get_bboxes')
def yolov3_head__get_bboxes(ctx,
                            self,
                            pred_maps,
                            with_nms=True,
                            cfg=None,
                            **kwargs):
    """Rewrite `get_bboxes` of YOLOV3Head for default backend.

    Transform network output for a batch into bbox predictions.

    Args:
        ctx: Context that contains original meta information.
        self: Represent the instance of the original class.
        pred_maps (list[Tensor]): Raw predictions for a batch of images.
        with_nms (bool): If True, do nms before return boxes.
            Default: True.
        cfg (mmcv.Config | None): Test / postprocessing configuration,
            if None, test_cfg would be used. Default: None.

    Returns:
        tuple[Tensor, Tensor]: If `with_nms` is True, the result of
            `multiclass_nms`: the first item is an (N, num_box, 5) tensor,
            where 5 represent (tl_x, tl_y, br_x, br_y, score), N is batch
            size and the score between 0 and 1. The shape of the second
            tensor in the tuple is (N, num_box), and each element
            represents the class label of the corresponding box.
            If `with_nms` is False, the decoded boxes and the per-class
            scores (class score multiplied by objectness) of all levels
            concatenated along dim 1 are returned instead.
    """
    is_dynamic_flag = is_dynamic_shape(ctx.cfg)
    pred_maps_list = [pred_map.detach() for pred_map in pred_maps]

    cfg = self.test_cfg if cfg is None else cfg
    assert len(pred_maps_list) == self.num_levels

    device = pred_maps_list[0].device
    batch_size = pred_maps_list[0].shape[0]

    featmap_sizes = [
        pred_maps_list[i].shape[-2:] for i in range(self.num_levels)
    ]
    multi_lvl_anchors = self.anchor_generator.grid_anchors(
        featmap_sizes, device)
    pre_topk = cfg.get('nms_pre', -1)
    # The backend is the same for every level, so query it only once
    # instead of once per loop iteration.
    backend = get_backend(ctx.cfg)
    multi_lvl_bboxes = []
    multi_lvl_cls_scores = []
    multi_lvl_conf_scores = []
    for i in range(self.num_levels):
        # get some key info for current scale
        pred_map = pred_maps_list[i]
        stride = self.featmap_strides[i]
        # (b, num_anchors*num_attrib, h, w) ->
        # (b, h*w*num_anchors, num_attrib)
        pred_map = pred_map.permute(0, 2, 3,
                                    1).reshape(batch_size, -1, self.num_attrib)
        # Inplace operation like
        # ```pred_map[..., :2] = \torch.sigmoid(pred_map[..., :2])```
        # would create constant tensor when exporting to onnx
        pred_map_conf = torch.sigmoid(pred_map[..., :2])
        pred_map_rest = pred_map[..., 2:]
        pred_map = torch.cat([pred_map_conf, pred_map_rest], dim=-1)
        pred_map_boxes = pred_map[..., :4]
        multi_lvl_anchor = multi_lvl_anchors[i]
        # use static anchor if input shape is static
        if not is_dynamic_flag:
            multi_lvl_anchor = multi_lvl_anchor.data
        multi_lvl_anchor = multi_lvl_anchor.unsqueeze(0).expand_as(
            pred_map_boxes)
        bbox_pred = self.bbox_coder.decode(multi_lvl_anchor, pred_map_boxes,
                                           stride)
        # conf and cls
        conf_pred = torch.sigmoid(pred_map[..., 4])
        cls_pred = torch.sigmoid(pred_map[..., 5:]).view(
            batch_size, -1, self.num_classes)  # Cls pred one-hot.

        if pre_topk > 0:
            # topk in tensorrt does not support shape<k, so concatenate
            # padding to enable topk. The padding is only needed when topk
            # actually runs; with the default `nms_pre` of -1 it would be
            # requested with a negative size.
            if backend == Backend.TENSORRT:
                bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
                conf_pred = pad_with_value(conf_pred, 1, pre_topk, 0.)
                cls_pred = pad_with_value(cls_pred, 1, pre_topk, 0.)

            _, topk_inds = conf_pred.topk(pre_topk)
            batch_inds = torch.arange(
                batch_size, device=device).view(-1,
                                                1).expand_as(topk_inds).long()
            # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
            # Index the flattened (batch * num_box) dimension instead of
            # using two-dimensional advanced indexing.
            transformed_inds = (bbox_pred.shape[1] * batch_inds + topk_inds)
            bbox_pred = bbox_pred.reshape(-1, 4)[transformed_inds, :].reshape(
                batch_size, -1, 4)
            cls_pred = cls_pred.reshape(
                -1, self.num_classes)[transformed_inds, :].reshape(
                    batch_size, -1, self.num_classes)
            conf_pred = conf_pred.reshape(-1, 1)[transformed_inds].reshape(
                batch_size, -1)

        # Save the result of current scale
        multi_lvl_bboxes.append(bbox_pred)
        multi_lvl_cls_scores.append(cls_pred)
        multi_lvl_conf_scores.append(conf_pred)

    # Merge the results of different scales together
    batch_mlvl_bboxes = torch.cat(multi_lvl_bboxes, dim=1)
    batch_mlvl_scores = torch.cat(multi_lvl_cls_scores, dim=1)
    batch_mlvl_conf_scores = torch.cat(multi_lvl_conf_scores, dim=1)

    post_params = get_mmdet_params(ctx.cfg)

    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    confidence_threshold = cfg.get('conf_thr',
                                   post_params.confidence_threshold)

    # follow original pipeline of YOLOv3: zero out (rather than remove)
    # entries below the thresholds so tensor shapes stay fixed
    if confidence_threshold > 0:
        mask = (batch_mlvl_conf_scores >= confidence_threshold).float()
        batch_mlvl_conf_scores *= mask
    if score_threshold > 0:
        mask = (batch_mlvl_scores > score_threshold).float()
        batch_mlvl_scores *= mask

    # final score = class score * objectness confidence
    batch_mlvl_conf_scores = batch_mlvl_conf_scores.unsqueeze(2).expand_as(
        batch_mlvl_scores)
    batch_mlvl_scores = batch_mlvl_scores * batch_mlvl_conf_scores

    if with_nms:
        max_output_boxes_per_class = post_params.max_output_boxes_per_class
        iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
        pre_top_k = post_params.pre_top_k
        keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
        # keep aligned with original pipeline, improve
        # mAP by 1% for YOLOv3 in ONNX
        # (thresholding was already applied by the masks above)
        score_threshold = 0
        return multiclass_nms(
            batch_mlvl_bboxes,
            batch_mlvl_scores,
            max_output_boxes_per_class,
            iou_threshold=iou_threshold,
            score_threshold=score_threshold,
            pre_top_k=pre_top_k,
            keep_top_k=keep_top_k)
    else:
        return batch_mlvl_bboxes, batch_mlvl_scores
|
||||||
|
|
||||||
|
|
||||||
|
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdet.models.YOLOV3Head.get_bboxes', backend='ncnn')
def yolov3_head__get_bboxes__ncnn(ctx,
                                  self,
                                  pred_maps,
                                  with_nms=True,
                                  cfg=None,
                                  **kwargs):
    """Rewrite `get_bboxes` of YOLOV3Head for ncnn backend.

    Decode the raw multi-level prediction maps into boxes and scores,
    using only operations that the ncnn export path can handle (a custom
    `mmdeploy::Threshold` op replaces the comparison operators).

    Args:
        ctx: Context that contains original meta information.
        self: Represent the instance of the original class.
        pred_maps (list[Tensor]): Raw predictions for a batch of images.
        with_nms (bool): If True, do nms before return boxes.
            Default: True.
        cfg (mmcv.Config | None): Test / postprocessing configuration,
            if None, test_cfg would be used. Default: None.

    Returns:
        tuple[Tensor, Tensor]: With `with_nms=True` the result of
            `multiclass_nms`: the first item is an (N, num_box, 5) tensor,
            where 5 represent (tl_x, tl_y, br_x, br_y, score), N is batch
            size and the score between 0 and 1. The shape of the second
            tensor in the tuple is (N, num_box), and each element
            represents the class label of the corresponding box.
            With `with_nms=False` the merged boxes and scores of all
            levels are returned without suppression.
    """

    class ThresholdOp(torch.autograd.Function):
        """Create Threshold op."""

        @staticmethod
        def forward(ctx, x, threshold):
            return x > threshold

        @staticmethod
        def symbolic(g, x, threshold):
            return g.op(
                'mmdeploy::Threshold', x, threshold_f=threshold, outputs=1)

    level_maps = [level_map.detach() for level_map in pred_maps]

    cfg = self.test_cfg if cfg is None else cfg
    assert len(level_maps) == self.num_levels

    first_map = level_maps[0]
    device = first_map.device
    batch_size = first_map.shape[0]

    featmap_sizes = [
        level_maps[lvl].shape[-2:] for lvl in range(self.num_levels)
    ]
    anchors_per_level = self.anchor_generator.grid_anchors(
        featmap_sizes, device)
    pre_topk = cfg.get('nms_pre', -1)

    level_bboxes = []
    level_cls_scores = []
    level_conf_scores = []
    for lvl in range(self.num_levels):
        stride = self.featmap_strides[lvl]
        # Move the channel axis last, then flatten to
        # (b, h*w*num_anchors, num_attrib).
        flat_map = level_maps[lvl].permute(0, 2, 3, 1)
        flat_map = flat_map.reshape(batch_size, -1, self.num_attrib)
        # Inplace operation like
        # ```pred_map[..., :2] = \torch.sigmoid(pred_map[..., :2])```
        # would create constant tensor when exporting to onnx, so the
        # tensor is rebuilt with a concatenation instead.
        # dim must be written as 2, but not -1, because ncnn implicit batch
        # mechanism.
        flat_map = torch.cat(
            [torch.sigmoid(flat_map[..., :2]), flat_map[..., 2:]], dim=2)
        box_deltas = flat_map[..., :4]
        # The ncnn rewrite always takes `.data` of the expanded anchors,
        # regardless of whether the input shape is dynamic.
        anchors = anchors_per_level[lvl].unsqueeze(0).expand_as(
            box_deltas).data

        bbox_pred = self.bbox_coder.decode(anchors, box_deltas, stride)
        # objectness confidence and per-class scores
        conf_pred = torch.sigmoid(flat_map[..., 4])
        cls_pred = torch.sigmoid(flat_map[..., 5:]).view(
            batch_size, -1, self.num_classes)  # Cls pred one-hot.

        if pre_topk > 0:
            topk_inds = conf_pred.topk(pre_topk)[1].view(-1)
            bbox_pred = bbox_pred[:, topk_inds, :]
            cls_pred = cls_pred[:, topk_inds, :]
            conf_pred = conf_pred[:, topk_inds]

        # Save the result of current scale
        level_bboxes.append(bbox_pred)
        level_cls_scores.append(cls_pred)
        level_conf_scores.append(conf_pred)

    # Merge the results of different scales together
    batch_mlvl_bboxes = torch.cat(level_bboxes, dim=1)
    batch_mlvl_scores = torch.cat(level_cls_scores, dim=1)
    batch_mlvl_conf_scores = torch.cat(level_conf_scores, dim=1)

    post_params = get_mmdet_params(ctx.cfg)

    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    confidence_threshold = cfg.get('conf_thr',
                                   post_params.confidence_threshold)

    # follow original pipeline of YOLOv3: zero out entries that do not
    # exceed the thresholds
    if confidence_threshold > 0:
        conf_mask = ThresholdOp.apply(batch_mlvl_conf_scores,
                                      confidence_threshold).float()
        batch_mlvl_conf_scores *= conf_mask
    if score_threshold > 0:
        score_mask = ThresholdOp.apply(batch_mlvl_scores,
                                       score_threshold).float()
        batch_mlvl_scores *= score_mask

    # NCNN broadcast needs the same in channel dimension.
    conf_4d = batch_mlvl_conf_scores.unsqueeze(2).unsqueeze(3)
    scores_4d = batch_mlvl_scores.unsqueeze(3)
    batch_mlvl_scores = (scores_4d * conf_4d).reshape(
        batch_mlvl_scores.shape)
    # Although batch_mlvl_bboxes already has the shape of
    # (batch_size, -1, 4), ncnn implicit batch mechanism in the model and
    # ncnn channel alignment would result in a shape of
    # (batch_size, -1, 4, 1). So, we need a reshape op to ensure the
    # batch_mlvl_bboxes shape is right.
    batch_mlvl_bboxes = batch_mlvl_bboxes.reshape(batch_size, -1, 4)

    if not with_nms:
        return batch_mlvl_bboxes, batch_mlvl_scores

    # keep aligned with original pipeline, improve
    # mAP by 1% for YOLOv3 in ONNX
    score_threshold = 0
    return multiclass_nms(
        batch_mlvl_bboxes,
        batch_mlvl_scores,
        post_params.max_output_boxes_per_class,
        iou_threshold=cfg.nms.get('iou_threshold',
                                  post_params.iou_threshold),
        score_threshold=score_threshold,
        pre_top_k=post_params.pre_top_k,
        keep_top_k=cfg.get('max_per_img', post_params.keep_top_k))
|
@ -786,3 +786,97 @@ def test_cascade_roi_head_with_mask(backend_type):
|
|||||||
assert np.allclose(
|
assert np.allclose(
|
||||||
expected_segm_results, segm_results, rtol=1e-03,
|
expected_segm_results, segm_results, rtol=1e-03,
|
||||||
atol=1e-05), 'segm_results do not match.'
|
atol=1e-05), 'segm_results do not match.'
|
||||||
|
|
||||||
|
|
||||||
|
def get_yolov3_head_model():
    """Build a small, frozen YOLOV3Head used by the rewrite tests."""
    head_test_cfg = mmcv.Config({
        'nms_pre': 1000,
        'min_bbox_size': 0,
        'score_thr': 0.05,
        'conf_thr': 0.005,
        'nms': {
            'type': 'nms',
            'iou_threshold': 0.45
        },
        'max_per_img': 100,
    })
    from mmdet.models import YOLOV3Head

    # Tiny channel counts keep the head cheap to construct.
    head = YOLOV3Head(
        num_classes=4,
        in_channels=[16, 8, 4],
        out_channels=[32, 16, 8],
        test_cfg=head_test_cfg)

    # Gradients are not needed for export / inference tests.
    head.requires_grad_(False)
    return head
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn', 'openvino'])
def test_yolov3_head_get_bboxes(backend_type):
    """Check the rewritten YOLOv3 `get_bboxes` against the PyTorch head."""
    pytest.importorskip(backend_type, reason=f'requires {backend_type}')
    yolov3_head = get_yolov3_head_model()
    yolov3_head.cpu().eval()

    input_size = 128
    image_shape = (input_size, input_size, 3)
    img_metas = [{
        'scale_factor': np.ones(4),
        'pad_shape': image_shape,
        'img_shape': image_shape
    }]

    output_names = ['dets', 'labels']
    post_processing = dict(
        score_threshold=0.05,
        iou_threshold=0.45,
        confidence_threshold=0.005,
        max_output_boxes_per_class=200,
        pre_top_k=-1,
        keep_top_k=100,
        background_label_id=-1,
    )
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type),
            onnx_config=dict(output_names=output_names, input_shape=None),
            codebase_config=dict(
                type='mmdet',
                task='ObjectDetection',
                post_processing=post_processing)))

    # Fixed seed so the random prediction maps are reproducible.
    seed_everything(1234)
    pred_maps = [
        torch.rand(1, 27, map_size, map_size) for map_size in (5, 10, 20)
    ]

    # Reference outputs from the unmodified PyTorch head.
    model_outputs = get_model_outputs(yolov3_head, 'get_bboxes', {
        'pred_maps': pred_maps,
        'img_metas': img_metas
    })

    # Outputs of the rewritten function, run through the backend if possible.
    wrapped_model = WrapModel(
        yolov3_head, 'get_bboxes', img_metas=img_metas[0], with_nms=True)
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs={'pred_maps': pred_maps},
        deploy_cfg=deploy_cfg)

    if not is_backend_output:
        assert rewrite_outputs is not None
        return

    if isinstance(rewrite_outputs, dict):
        rewrite_outputs = [
            tensor for key, tensor in rewrite_outputs.items()
            if key in output_names
        ]
    for expected, actual in zip(model_outputs[0], rewrite_outputs):
        expected = expected.squeeze().cpu().numpy()
        actual = actual.squeeze()
        # hard code to make two tensors with the same shape
        # rewrite and original codes applied different nms strategy
        num_kept = actual.shape[0]
        assert np.allclose(
            expected[:num_kept], actual, rtol=1e-03, atol=1e-05)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user