diff --git a/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp b/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp index e6ef8b500..c3ad05bf9 100644 --- a/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp +++ b/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp @@ -3403,8 +3403,8 @@ int main(int argc, char** argv) { fprintf(pp, "%-16s", "LayerNorm"); } else if (op == "LeakyRelu") { fprintf(pp, "%-16s", "ReLU"); - } else if (op == "Less") { - fprintf(pp, "%-16s", "Compare"); + } else if (op == "Threshold") { + fprintf(pp, "%-16s", "Threshold"); } else if (op == "Log") { fprintf(pp, "%-16s", "UnaryOp"); } else if (op == "LRN") { @@ -4315,11 +4315,10 @@ int main(int argc, char** argv) { } } else if (op == "LeakyRelu") { float alpha = get_node_attr_f(node, "alpha", 0.01f); - fprintf(pp, " 0=%e", alpha); - } else if (op == "Less") { - int op_type = 1; - fprintf(pp, " 0=%d", op_type); + } else if (op == "Threshold") { + float threshold = get_node_attr_f(node, "threshold", 0.f); + fprintf(pp, " 0=%e", threshold); } else if (op == "Log") { int op_type = 8; fprintf(pp, " 0=%d", op_type); diff --git a/configs/mmdet/_base_/base_static.py b/configs/mmdet/_base_/base_static.py index 2570694fd..e2cdf0f37 100644 --- a/configs/mmdet/_base_/base_static.py +++ b/configs/mmdet/_base_/base_static.py @@ -6,6 +6,7 @@ codebase_config = dict( task='ObjectDetection', post_processing=dict( score_threshold=0.05, + confidence_threshold=0.005, # for YOLOv3 iou_threshold=0.5, max_output_boxes_per_class=200, pre_top_k=-1, diff --git a/configs/mmdet/single-stage/single-stage_tensorrt_dynamic-160x160-608x608.py b/configs/mmdet/single-stage/single-stage_tensorrt_dynamic-160x160-608x608.py new file mode 100644 index 000000000..b23547e0b --- /dev/null +++ b/configs/mmdet/single-stage/single-stage_tensorrt_dynamic-160x160-608x608.py @@ -0,0 +1,12 @@ +_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/tensorrt.py'] + +backend_config = dict( + common_config=dict(max_workspace_size=1 << 30), + model_inputs=[ + dict( + input_shapes=dict( + input=dict( + min_shape=[1, 3, 160, 160], + opt_shape=[1, 3, 608, 608], + max_shape=[1, 3, 608, 608]))) + ]) diff --git a/docs/tutorials/how_to_convert_model.md b/docs/tutorials/how_to_convert_model.md index 65fe9d13d..a236d8ead 100644 --- a/docs/tutorials/how_to_convert_model.md +++ b/docs/tutorials/how_to_convert_model.md @@ -2,15 +2,16 @@ -- [Tutorial : How to convert model](#how-to-convert-model) - - [How to convert models from Pytorch to BACKEND](#how-to-convert-models-from-pytorch-to-other-backends) - - [Prerequisite](#prerequisite) - - [Usage](#usage) - - [Description of all arguments](#description-of-all-arguments) - - [How to evaluate the exported models](#how-to-evaluate-the-exported-models) - - [List of supported models exportable to BACKEND](#list-of-supported-models-exportable-to-other-backends) - - [Reminders](#reminders) - - [FAQs](#faqs) +- [How to convert model](#how-to-convert-model) + - [How to convert models from Pytorch to other backends](#how-to-convert-models-from-pytorch-to-other-backends) + - [Prerequisite](#prerequisite) + - [Usage](#usage) + - [Description of all arguments](#description-of-all-arguments) + - [Example](#example) + - [How to evaluate the exported models](#how-to-evaluate-the-exported-models) + - [List of supported models exportable to other backends](#list-of-supported-models-exportable-to-other-backends) + - [Reminders](#reminders) + - [FAQs](#faqs) @@ -78,33 +79,33 @@ You can try to evaluate model, referring to [how_to_evaluate_a_model](./how_to_e The table below lists the models that are guaranteed to be exportable to other backend. -| Model | codebase | model config file(example) | OnnxRuntime | TensorRT | NCNN | PPL | OpenVINO | -| :----------------: | :--------------: | :---------------------------------------------------------------------------------------: | :---------: | :-----------: | :---:| :---: | :-------: | -| RetinaNet | MMDetection | $PATH_TO_MMDET/configs/retinanet/retinanet_r50_fpn_1x_coco.py | Y | Y | Y | Y | Y | -| Faster R-CNN | MMDetection | $PATH_TO_MMDET/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py | Y | Y | Y | Y | Y | -| YOLOv3 | MMDetection | $PATH_TO_MMDET/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py | Y | Y | N | Y | N | -| FCOS | MMDetection | $PATH_TO_MMDET/configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py | Y | Y | Y | N | Y | -| FSAF | MMDetection | $PATH_TO_MMDET/configs/fsaf/fsaf_r50_fpn_1x_coco.py | Y | Y | Y | Y | Y | -| Mask R-CNN | MMDetection | $PATH_TO_MMDET/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py | Y | Y | N | Y | Y | -| SSD | MMDetection | $PATH_TO_MMDET/configs/ssd/ssd300_coco.py | Y | ? | ? | ? | Y | -| FoveaBox | MMDetection | $PATH_TO_MMDET/configs/foveabox/fovea_r50_fpn_4x4_1x_coco.py | Y | ? | ? | ? | Y | -| ATSS | MMDetection | $PATH_TO_MMDET/configs/atss/atss_r50_fpn_1x_coco.py | Y | ? | ? | ? | Y | -| Cascade R-CNN | MMDetection | $PATH_TO_MMDET/configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py | Y | ? | ? | ? | Y | -| Cascade Mask R-CNN | MMDetection | $PATH_TO_MMDET/configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py | Y | ? | ? | ? | Y | -| ResNet | MMClassification | $PATH_TO_MMCLS/configs/resnet/resnet18_b32x8_imagenet.py | Y | Y | Y | Y | N | -| ResNeXt | MMClassification | $PATH_TO_MMCLS/configs/resnext/resnext50_32x4d_b32x8_imagenet.py | Y | Y | Y | Y | N | -| SE-ResNet | MMClassification | $PATH_TO_MMCLS/configs/seresnet/seresnet50_b32x8_imagenet.py | Y | Y | Y | Y | N | -| MobileNetV2 | MMClassification | $PATH_TO_MMCLS/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py | Y | Y | Y | Y | N | -| ShuffleNetV1 | MMClassification | $PATH_TO_MMCLS/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py | Y | Y | N | Y | N | -| ShuffleNetV2 | MMClassification | $PATH_TO_MMCLS/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py | Y | Y | N | Y | N | -| FCN | MMSegmentation | $PATH_TO_MMSEG/configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | N | -| PSPNet | MMSegmentation | $PATH_TO_MMSEG/configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py | Y | Y | N | Y | N | -| DeepLabV3 | MMSegmentation | $PATH_TO_MMSEG/configs/deeplabv3/deeplabv3_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | N | -| DeepLabV3+ | MMSegmentation | $PATH_TO_MMSEG/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | N | -| SRCNN | MMEditing | $PATH_TO_MMSEG/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py | Y | Y | N | Y | N | -| ESRGAN | MMEditing | $PATH_TO_MMSEG/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py | Y | Y | N | Y | N | -| DBNet | MMOCR | $PATH_TO_MMOCR/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py | Y | Y | Y | Y | N | -| CRNN | MMOCR | $PATH_TO_MMOCR/configs/textrecog/tps/crnn_tps_academic_dataset.py | Y | Y | Y | N | N | +| Model | codebase | OnnxRuntime | TensorRT | NCNN | PPL | OpenVINO | model config file(example) | +|--------------------|------------------|:-----------:|:--------:|:----:|:---:|:--------:|:--------------------------------------------------------------------------------------| +| RetinaNet | MMDetection | Y | Y | Y | Y | Y | $MMDET_DIR/configs/retinanet/retinanet_r50_fpn_1x_coco.py | +| Faster R-CNN | MMDetection | Y | Y | Y | Y | Y | $MMDET_DIR/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py | +| YOLOv3 | MMDetection | Y | Y | Y | Y | Y | $MMDET_DIR/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py | +| FCOS | MMDetection | Y | Y | Y | N | Y | $MMDET_DIR/configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py | +| FSAF | MMDetection | Y | Y | Y | Y | Y | $MMDET_DIR/configs/fsaf/fsaf_r50_fpn_1x_coco.py | +| Mask R-CNN | MMDetection | Y | Y | N | Y | Y | $MMDET_DIR/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py | +| SSD | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/ssd/ssd300_coco.py | +| FoveaBox | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/foveabox/fovea_r50_fpn_4x4_1x_coco.py | +| ATSS | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/atss/atss_r50_fpn_1x_coco.py | +| Cascade R-CNN | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py | +| Cascade Mask R-CNN | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py | +| ResNet | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/resnet/resnet18_b32x8_imagenet.py | +| ResNeXt | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/resnext/resnext50_32x4d_b32x8_imagenet.py | +| SE-ResNet | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/seresnet/seresnet50_b32x8_imagenet.py | +| MobileNetV2 | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py | +| ShuffleNetV1 | MMClassification | Y | Y | N | Y | N | $MMCLS_DIR/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py | +| ShuffleNetV2 | MMClassification | Y | Y | N | Y | N | $MMCLS_DIR/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py | +| FCN | MMSegmentation | Y | Y | Y | Y | N | $MMSEG_DIR/configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py | +| PSPNet | MMSegmentation | Y | Y | N | Y | N | $MMSEG_DIR/configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py | +| DeepLabV3 | MMSegmentation | Y | Y | Y | Y | N | $MMSEG_DIR/configs/deeplabv3/deeplabv3_r50-d8_512x1024_40k_cityscapes.py | +| DeepLabV3+ | MMSegmentation | Y | Y | Y | Y | N | $MMSEG_DIR/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py | +| SRCNN | MMEditing | Y | Y | N | Y | N | $MMSEG_DIR/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py | +| ESRGAN | MMEditing | Y | Y | N | Y | N | $MMSEG_DIR/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py | +| DBNet | MMOCR | Y | Y | Y | Y | N | $MMOCR_DIR/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py | +| CRNN | MMOCR | Y | Y | Y | N | N | $MMOCR_DIR/configs/textrecog/tps/crnn_tps_academic_dataset.py | ### Reminders diff --git a/mmdeploy/mmcv/ops/nms.py b/mmdeploy/mmcv/ops/nms.py index 3523cfd57..7a99f5b07 100644 --- a/mmdeploy/mmcv/ops/nms.py +++ b/mmdeploy/mmcv/ops/nms.py @@ -20,7 +20,8 @@ class ONNXNMSop(torch.autograd.Function): for batch_id in range(batch_size): for cls_id in range(num_class): _boxes = boxes[batch_id, ...] - _scores = scores[batch_id, cls_id, ...] + # score_threshold=0 requires scores to be contiguous + _scores = scores[batch_id, cls_id, ...].contiguous() _, box_inds = nms( _boxes, _scores, diff --git a/mmdeploy/mmdet/models/dense_heads/__init__.py b/mmdeploy/mmdet/models/dense_heads/__init__.py index fb05b72b1..9234fb257 100644 --- a/mmdeploy/mmdet/models/dense_heads/__init__.py +++ b/mmdeploy/mmdet/models/dense_heads/__init__.py @@ -3,9 +3,11 @@ from .atss_head import get_bboxes_of_atss_head from .fcos_head import get_bboxes_of_fcos_head from .fovea_head import get_bboxes_of_fovea_head from .rpn_head import get_bboxes_of_rpn_head +from .yolo_head import yolov3_head__get_bboxes, yolov3_head__get_bboxes__ncnn __all__ = [ 'get_bboxes_of_anchor_head', 'get_bboxes_of_fcos_head', 'get_bboxes_of_rpn_head', 'get_bboxes_of_fovea_head', - 'get_bboxes_of_atss_head' + 'get_bboxes_of_atss_head', 'yolov3_head__get_bboxes', + 'yolov3_head__get_bboxes__ncnn' ] diff --git a/mmdeploy/mmdet/models/dense_heads/fcos_head.py b/mmdeploy/mmdet/models/dense_heads/fcos_head.py index 7e2dd710c..93b7dceb2 100644 --- a/mmdeploy/mmdet/models/dense_heads/fcos_head.py +++ b/mmdeploy/mmdet/models/dense_heads/fcos_head.py @@ -102,10 +102,11 @@ def get_bboxes_of_fcos_head(ctx, max_output_boxes_per_class = post_params.max_output_boxes_per_class iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) score_threshold = cfg.get('score_thr', post_params.score_threshold) - nms_pre = cfg.get('deploy_nms_pre', -1) + pre_top_k = post_params.pre_top_k + keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) return multiclass_nms(batch_mlvl_bboxes, batch_mlvl_scores, max_output_boxes_per_class, iou_threshold, - score_threshold, nms_pre, cfg.max_per_img) + score_threshold, pre_top_k, keep_top_k) @FUNCTION_REWRITER.register_rewriter( @@ -195,7 +196,8 @@ def get_bboxes_of_fcos_head_ncnn(ctx, max_output_boxes_per_class = post_params.max_output_boxes_per_class iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) score_threshold = cfg.get('score_thr', post_params.score_threshold) - nms_pre = cfg.get('deploy_nms_pre', -1) + pre_top_k = post_params.pre_top_k + keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) return multiclass_nms(batch_mlvl_bboxes, batch_mlvl_scores, max_output_boxes_per_class, iou_threshold, - score_threshold, nms_pre, cfg.max_per_img) + score_threshold, pre_top_k, keep_top_k) diff --git a/mmdeploy/mmdet/models/dense_heads/yolo_head.py b/mmdeploy/mmdet/models/dense_heads/yolo_head.py new file mode 100644 index 000000000..e0c285404 --- /dev/null +++ b/mmdeploy/mmdet/models/dense_heads/yolo_head.py @@ -0,0 +1,309 @@ +import torch + +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.mmdet.core import multiclass_nms +from mmdeploy.mmdet.export import pad_with_value +from mmdeploy.utils import (Backend, get_backend, get_mmdet_params, + is_dynamic_shape) + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmdet.models.YOLOV3Head.get_bboxes') +def yolov3_head__get_bboxes(ctx, + self, + pred_maps, + with_nms=True, + cfg=None, + **kwargs): + """Rewrite `get_bboxes` for default backend. + + Transform network output for a batch into bbox predictions. + + Args: + ctx: Context that contains original meta information. + self: Represent the instance of the original class. + pred_maps (list[Tensor]): Raw predictions for a batch of images. + cfg (mmcv.Config | None): Test / postprocessing configuration, + if None, test_cfg would be used. Default: None. + with_nms (bool): If True, do nms before return boxes. + Default: True. + + Returns: + tuple[Tensor, Tensor]: The first item is an (N, num_box, 5) tensor, + where 5 represent (tl_x, tl_y, br_x, br_y, score), N is batch + size and the score between 0 and 1. The shape of the second + tensor in the tuple is (N, num_box), and each element + represents the class label of the corresponding box. + """ + is_dynamic_flag = is_dynamic_shape(ctx.cfg) + num_levels = len(pred_maps) + pred_maps_list = [pred_maps[i].detach() for i in range(num_levels)] + + cfg = self.test_cfg if cfg is None else cfg + assert len(pred_maps_list) == self.num_levels + + device = pred_maps_list[0].device + batch_size = pred_maps_list[0].shape[0] + + featmap_sizes = [ + pred_maps_list[i].shape[-2:] for i in range(self.num_levels) + ] + multi_lvl_anchors = self.anchor_generator.grid_anchors( + featmap_sizes, device) + pre_topk = cfg.get('nms_pre', -1) + multi_lvl_bboxes = [] + multi_lvl_cls_scores = [] + multi_lvl_conf_scores = [] + for i in range(self.num_levels): + # get some key info for current scale + pred_map = pred_maps_list[i] + stride = self.featmap_strides[i] + # (b,h, w, num_anchors*num_attrib) -> + # (b,h*w*num_anchors, num_attrib) + pred_map = pred_map.permute(0, 2, 3, + 1).reshape(batch_size, -1, self.num_attrib) + # Inplace operation like + # ```pred_map[..., :2] = \torch.sigmoid(pred_map[..., :2])``` + # would create constant tensor when exporting to onnx + pred_map_conf = torch.sigmoid(pred_map[..., :2]) + pred_map_rest = pred_map[..., 2:] + pred_map = torch.cat([pred_map_conf, pred_map_rest], dim=-1) + pred_map_boxes = pred_map[..., :4] + multi_lvl_anchor = multi_lvl_anchors[i] + # use static anchor if input shape is static + if not is_dynamic_flag: + multi_lvl_anchor = multi_lvl_anchor.data + multi_lvl_anchor = multi_lvl_anchor.unsqueeze(0).expand_as( + pred_map_boxes) + bbox_pred = self.bbox_coder.decode(multi_lvl_anchor, pred_map_boxes, + stride) + # conf and cls + conf_pred = torch.sigmoid(pred_map[..., 4]) + cls_pred = torch.sigmoid(pred_map[..., 5:]).view( + batch_size, -1, self.num_classes) # Cls pred one-hot. + + backend = get_backend(ctx.cfg) + # topk in tensorrt does not support shape 0: + _, topk_inds = conf_pred.topk(pre_topk) + batch_inds = torch.arange( + batch_size, device=device).view(-1, + 1).expand_as(topk_inds).long() + # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501 + transformed_inds = (bbox_pred.shape[1] * batch_inds + topk_inds) + bbox_pred = bbox_pred.reshape(-1, 4)[transformed_inds, :].reshape( + batch_size, -1, 4) + cls_pred = cls_pred.reshape( + -1, self.num_classes)[transformed_inds, :].reshape( + batch_size, -1, self.num_classes) + conf_pred = conf_pred.reshape(-1, 1)[transformed_inds].reshape( + batch_size, -1) + + # Save the result of current scale + multi_lvl_bboxes.append(bbox_pred) + multi_lvl_cls_scores.append(cls_pred) + multi_lvl_conf_scores.append(conf_pred) + + # Merge the results of different scales together + batch_mlvl_bboxes = torch.cat(multi_lvl_bboxes, dim=1) + batch_mlvl_scores = torch.cat(multi_lvl_cls_scores, dim=1) + batch_mlvl_conf_scores = torch.cat(multi_lvl_conf_scores, dim=1) + + post_params = get_mmdet_params(ctx.cfg) + + score_threshold = cfg.get('score_thr', post_params.score_threshold) + confidence_threshold = cfg.get('conf_thr', + post_params.confidence_threshold) + + # follow original pipeline of YOLOv3 + if confidence_threshold > 0: + mask = (batch_mlvl_conf_scores >= confidence_threshold).float() + batch_mlvl_conf_scores *= mask + if score_threshold > 0: + mask = (batch_mlvl_scores > score_threshold).float() + batch_mlvl_scores *= mask + + batch_mlvl_conf_scores = batch_mlvl_conf_scores.unsqueeze(2).expand_as( + batch_mlvl_scores) + batch_mlvl_scores = batch_mlvl_scores * batch_mlvl_conf_scores + + if with_nms: + max_output_boxes_per_class = post_params.max_output_boxes_per_class + iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) + pre_top_k = post_params.pre_top_k + keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) + # keep aligned with original pipeline, improve + # mAP by 1% for YOLOv3 in ONNX + score_threshold = 0 + return multiclass_nms( + batch_mlvl_bboxes, + batch_mlvl_scores, + max_output_boxes_per_class, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k) + else: + return batch_mlvl_bboxes, batch_mlvl_scores + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmdet.models.YOLOV3Head.get_bboxes', backend='ncnn') +def yolov3_head__get_bboxes__ncnn(ctx, + self, + pred_maps, + with_nms=True, + cfg=None, + **kwargs): + """Rewrite `get_bboxes` for ncnn backend. + + Transform network output for a batch into bbox predictions. + + Args: + ctx: Context that contains original meta information. + self: Represent the instance of the original class. + pred_maps (list[Tensor]): Raw predictions for a batch of images. + cfg (mmcv.Config | None): Test / postprocessing configuration, + if None, test_cfg would be used. Default: None. + with_nms (bool): If True, do nms before return boxes. + Default: True. + + Returns: + tuple[Tensor, Tensor]: The first item is an (N, num_box, 5) tensor, + where 5 represent (tl_x, tl_y, br_x, br_y, score), N is batch + size and the score between 0 and 1. The shape of the second + tensor in the tuple is (N, num_box), and each element + represents the class label of the corresponding box. + """ + num_levels = len(pred_maps) + pred_maps_list = [pred_maps[i].detach() for i in range(num_levels)] + + cfg = self.test_cfg if cfg is None else cfg + assert len(pred_maps_list) == self.num_levels + + device = pred_maps_list[0].device + batch_size = pred_maps_list[0].shape[0] + + featmap_sizes = [ + pred_maps_list[i].shape[-2:] for i in range(self.num_levels) + ] + multi_lvl_anchors = self.anchor_generator.grid_anchors( + featmap_sizes, device) + pre_topk = cfg.get('nms_pre', -1) + multi_lvl_bboxes = [] + multi_lvl_cls_scores = [] + multi_lvl_conf_scores = [] + for i in range(self.num_levels): + # get some key info for current scale + pred_map = pred_maps_list[i] + stride = self.featmap_strides[i] + # (b,h, w, num_anchors*num_attrib) -> + # (b,h*w*num_anchors, num_attrib) + pred_map = pred_map.permute(0, 2, 3, + 1).reshape(batch_size, -1, self.num_attrib) + # Inplace operation like + # ```pred_map[..., :2] = \torch.sigmoid(pred_map[..., :2])``` + # would create constant tensor when exporting to onnx + pred_map_conf = torch.sigmoid(pred_map[..., :2]) + pred_map_rest = pred_map[..., 2:] + # dim must be written as 2, but not -1, because ncnn implicit batch + # mechanism. + pred_map = torch.cat([pred_map_conf, pred_map_rest], dim=2) + pred_map_boxes = pred_map[..., :4] + multi_lvl_anchor = multi_lvl_anchors[i] + # use static anchor if input shape is static + multi_lvl_anchor = multi_lvl_anchor.unsqueeze(0).expand_as( + pred_map_boxes).data + + bbox_pred = self.bbox_coder.decode(multi_lvl_anchor, pred_map_boxes, + stride) + # conf and cls + conf_pred = torch.sigmoid(pred_map[..., 4]) + cls_pred = torch.sigmoid(pred_map[..., 5:]).view( + batch_size, -1, self.num_classes) # Cls pred one-hot. + + if pre_topk > 0: + _, topk_inds = conf_pred.topk(pre_topk) + topk_inds = topk_inds.view(-1) + bbox_pred = bbox_pred[:, topk_inds, :] + cls_pred = cls_pred[:, topk_inds, :] + conf_pred = conf_pred[:, topk_inds] + + # Save the result of current scale + multi_lvl_bboxes.append(bbox_pred) + multi_lvl_cls_scores.append(cls_pred) + multi_lvl_conf_scores.append(conf_pred) + + # Merge the results of different scales together + batch_mlvl_bboxes = torch.cat(multi_lvl_bboxes, dim=1) + batch_mlvl_scores = torch.cat(multi_lvl_cls_scores, dim=1) + batch_mlvl_conf_scores = torch.cat(multi_lvl_conf_scores, dim=1) + + post_params = get_mmdet_params(ctx.cfg) + + score_threshold = cfg.get('score_thr', post_params.score_threshold) + confidence_threshold = cfg.get('conf_thr', + post_params.confidence_threshold) + + # helper function for creating Threshold op + def _create_threshold(x, thresh): + + class ThresholdOp(torch.autograd.Function): + """Create Threshold op.""" + + @staticmethod + def forward(ctx, x, threshold): + return x > threshold + + @staticmethod + def symbolic(g, x, threshold): + return g.op( + 'mmdeploy::Threshold', x, threshold_f=threshold, outputs=1) + + return ThresholdOp.apply(x, thresh) + + # follow original pipeline of YOLOv3 + if confidence_threshold > 0: + mask = _create_threshold(batch_mlvl_conf_scores, + confidence_threshold).float() + batch_mlvl_conf_scores *= mask + if score_threshold > 0: + mask = _create_threshold(batch_mlvl_scores, score_threshold).float() + batch_mlvl_scores *= mask + + # NCNN broadcast needs the same in channel dimension. + _batch_mlvl_conf_scores = batch_mlvl_conf_scores.unsqueeze(2).unsqueeze(3) + _batch_mlvl_scores = batch_mlvl_scores.unsqueeze(3) + batch_mlvl_scores = (_batch_mlvl_scores * _batch_mlvl_conf_scores).reshape( + batch_mlvl_scores.shape) + # Although batch_mlvl_bboxes already has the shape of + # (batch_size, -1, 4), ncnn implicit batch mechanism in the model and + # ncnn channel alignment would result in a shape of + # (batch_size, -1, 4, 1). So, we need a reshape op to ensure the + # batch_mlvl_bboxes shape is right. + batch_mlvl_bboxes = batch_mlvl_bboxes.reshape(batch_size, -1, 4) + + if with_nms: + max_output_boxes_per_class = post_params.max_output_boxes_per_class + iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) + pre_top_k = post_params.pre_top_k + keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) + # keep aligned with original pipeline, improve + # mAP by 1% for YOLOv3 in ONNX + score_threshold = 0 + return multiclass_nms( + batch_mlvl_bboxes, + batch_mlvl_scores, + max_output_boxes_per_class, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k) + else: + return batch_mlvl_bboxes, batch_mlvl_scores diff --git a/tests/test_mmdet/test_mmdet_models.py b/tests/test_mmdet/test_mmdet_models.py index 16c6c303f..fe6dbbf87 100644 --- a/tests/test_mmdet/test_mmdet_models.py +++ b/tests/test_mmdet/test_mmdet_models.py @@ -786,3 +786,97 @@ def test_cascade_roi_head_with_mask(backend_type): assert np.allclose( expected_segm_results, segm_results, rtol=1e-03, atol=1e-05), 'segm_results do not match.' + + +def get_yolov3_head_model(): + """yolov3 Head Config.""" + test_cfg = mmcv.Config( + dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + conf_thr=0.005, + nms=dict(type='nms', iou_threshold=0.45), + max_per_img=100)) + from mmdet.models import YOLOV3Head + model = YOLOV3Head( + num_classes=4, + in_channels=[16, 8, 4], + out_channels=[32, 16, 8], + test_cfg=test_cfg) + + model.requires_grad_(False) + return model + + +@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn', 'openvino']) +def test_yolov3_head_get_bboxes(backend_type): + """Test get_bboxes rewrite of yolov3 head.""" + pytest.importorskip(backend_type, reason=f'requires {backend_type}') + yolov3_head = get_yolov3_head_model() + yolov3_head.cpu().eval() + s = 128 + img_metas = [{ + 'scale_factor': np.ones(4), + 'pad_shape': (s, s, 3), + 'img_shape': (s, s, 3) + }] + + output_names = ['dets', 'labels'] + deploy_cfg = mmcv.Config( + dict( + backend_config=dict(type=backend_type), + onnx_config=dict(output_names=output_names, input_shape=None), + codebase_config=dict( + type='mmdet', + task='ObjectDetection', + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.45, + confidence_threshold=0.005, + max_output_boxes_per_class=200, + pre_top_k=-1, + keep_top_k=100, + background_label_id=-1, + )))) + + seed_everything(1234) + pred_maps = [ + torch.rand(1, 27, 5, 5), + torch.rand(1, 27, 10, 10), + torch.rand(1, 27, 20, 20) + ] + # to get outputs of pytorch model + model_inputs = {'pred_maps': pred_maps, 'img_metas': img_metas} + model_outputs = get_model_outputs(yolov3_head, 'get_bboxes', model_inputs) + + # to get outputs of onnx model after rewrite + wrapped_model = WrapModel( + yolov3_head, 'get_bboxes', img_metas=img_metas[0], with_nms=True) + rewrite_inputs = { + 'pred_maps': pred_maps, + } + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + + if is_backend_output: + if isinstance(rewrite_outputs, dict): + rewrite_outputs = [ + value for name, value in rewrite_outputs.items() + if name in output_names + ] + for model_output, rewrite_output in zip(model_outputs[0], + rewrite_outputs): + model_output = model_output.squeeze().cpu().numpy() + rewrite_output = rewrite_output.squeeze() + # hard code to make two tensors with the same shape + # rewrite and original codes applied different nms strategy + assert np.allclose( + model_output[:rewrite_output.shape[0]], + rewrite_output, + rtol=1e-03, + atol=1e-05) + else: + assert rewrite_outputs is not None