[Feature] Support YOLOv5 YOLOv6 YOLOX Deploy in mmdeploy (#199)

* Support YOLOv5 YOLOv6 YOLOX Deploy in mmdeploy

* Fix lint

* Rename _class to detector_type

* Add some common

* fix lint

Co-authored-by: huanghaian <huanghaian@sensetime.com>
pull/249/head
tripleMu 2022-11-01 17:08:42 +08:00 committed by Haian Huang(深度眸)
parent 190ee5aaa7
commit 275beec782
9 changed files with 184 additions and 24 deletions

View File

@ -6,7 +6,8 @@ backend_config = dict(
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 320, 320],
min_shape=[1, 3, 192, 192],
opt_shape=[1, 3, 640, 640],
max_shape=[1, 3, 640, 640])))
max_shape=[1, 3, 960, 960])))
])
use_efficientnms = False # whether to replace TRTBatchedNMS plugin with EfficientNMS plugin # noqa E501

View File

@ -11,3 +11,4 @@ backend_config = dict(
opt_shape=[1, 3, 640, 640],
max_shape=[1, 3, 640, 640])))
])
use_efficientnms = False # whether to replace TRTBatchedNMS plugin with EfficientNMS plugin # noqa E501

View File

@ -7,8 +7,9 @@ backend_config = dict(
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 320, 320],
min_shape=[1, 3, 192, 192],
opt_shape=[1, 3, 640, 640],
max_shape=[1, 3, 640, 640])))
max_shape=[1, 3, 960, 960])))
],
calib_config=dict(create_calib=True, calib_file='calib_data.h5'))
use_efficientnms = False # whether to replace TRTBatchedNMS plugin with EfficientNMS plugin # noqa E501

View File

@ -13,3 +13,4 @@ backend_config = dict(
max_shape=[1, 3, 640, 640])))
],
calib_config=dict(create_calib=True, calib_file='calib_data.h5'))
use_efficientnms = False # whether to replace TRTBatchedNMS plugin with EfficientNMS plugin # noqa E501

View File

@ -6,7 +6,8 @@ backend_config = dict(
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 320, 320],
min_shape=[1, 3, 192, 192],
opt_shape=[1, 3, 640, 640],
max_shape=[1, 3, 640, 640])))
max_shape=[1, 3, 960, 960])))
])
use_efficientnms = False # whether to replace TRTBatchedNMS plugin with EfficientNMS plugin # noqa E501

View File

@ -11,3 +11,4 @@ backend_config = dict(
opt_shape=[1, 3, 640, 640],
max_shape=[1, 3, 640, 640])))
])
use_efficientnms = False # whether to replace TRTBatchedNMS plugin with EfficientNMS plugin # noqa E501

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from functools import partial
from typing import List, Optional, Tuple
import torch
@ -10,6 +11,28 @@ from mmengine.config import ConfigDict
from mmengine.structures import InstanceData
from torch import Tensor
from mmyolo.deploy.models.layers import efficient_nms
from mmyolo.models.dense_heads import YOLOv5Head
def yolov5_bbox_decoder(priors, bbox_preds, stride):
bbox_preds = bbox_preds.sigmoid()
x_center = (priors[..., 0] + priors[..., 2]) * 0.5
y_center = (priors[..., 1] + priors[..., 3]) * 0.5
w = priors[..., 2] - priors[..., 0]
h = priors[..., 3] - priors[..., 1]
x_center_pred = (bbox_preds[..., 0] - 0.5) * 2 * stride + x_center
y_center_pred = (bbox_preds[..., 1] - 0.5) * 2 * stride + y_center
w_pred = (bbox_preds[..., 2] * 2)**2 * w
h_pred = (bbox_preds[..., 3] * 2)**2 * h
decoded_bboxes = torch.stack(
[x_center_pred, y_center_pred, w_pred, h_pred], dim=-1)
return decoded_bboxes
@FUNCTION_REWRITER.register_rewriter(
func_name='mmyolo.models.dense_heads.yolov5_head.'
@ -18,7 +41,7 @@ def yolov5_head__predict_by_feat(ctx,
self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
objectnesses: Optional[List[Tensor]],
objectnesses: Optional[List[Tensor]] = None,
batch_img_metas: Optional[List[dict]] = None,
cfg: Optional[ConfigDict] = None,
rescale: bool = False,
@ -51,6 +74,20 @@ def yolov5_head__predict_by_feat(ctx,
tensor in the tuple is (N, num_box), and each element
represents the class label of the corresponding box.
"""
detector_type = type(self)
deploy_cfg = ctx.cfg
use_efficientnms = deploy_cfg.get('use_efficientnms', False)
dtype = cls_scores[0].dtype
device = cls_scores[0].device
bbox_decoder = self.bbox_coder.decode
nms_func = multiclass_nms
if use_efficientnms:
if detector_type is YOLOv5Head:
nms_func = partial(efficient_nms, box_coding=0)
bbox_decoder = yolov5_bbox_decoder
else:
nms_func = efficient_nms
assert len(cls_scores) == len(bbox_preds)
cfg = self.test_cfg if cfg is None else cfg
cfg = copy.deepcopy(cfg)
@ -59,7 +96,8 @@ def yolov5_head__predict_by_feat(ctx,
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes, dtype=cls_scores[0].dtype, device=cls_scores[0].device)
featmap_sizes, dtype=dtype, device=device)
flatten_priors = torch.cat(mlvl_priors)
mlvl_strides = [
@ -69,33 +107,36 @@ def yolov5_head__predict_by_feat(ctx,
for featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
]
flatten_stride = torch.cat(mlvl_strides)
# flatten cls_scores, bbox_preds and objectness
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes)
for cls_score in cls_scores
]
cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
for bbox_pred in bbox_preds
]
flatten_objectness = [
objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
for objectness in objectnesses
]
cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
bboxes = self.bbox_coder.decode(flatten_priors[None], flatten_bbox_preds,
flatten_stride)
# directly multiply score factor and feed to nms
scores = cls_scores * (flatten_objectness.unsqueeze(-1))
if objectnesses is not None:
flatten_objectness = [
objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
for objectness in objectnesses
]
flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
cls_scores = cls_scores * (flatten_objectness.unsqueeze(-1))
scores = cls_scores
bboxes = bbox_decoder(flatten_priors[None], flatten_bbox_preds,
flatten_stride)
if not with_nms:
return bboxes, scores
deploy_cfg = ctx.cfg
post_params = get_post_processing_params(deploy_cfg)
max_output_boxes_per_class = post_params.max_output_boxes_per_class
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
@ -103,6 +144,5 @@ def yolov5_head__predict_by_feat(ctx,
pre_top_k = post_params.pre_top_k
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
return multiclass_nms(bboxes, scores, max_output_boxes_per_class,
iou_threshold, score_threshold, pre_top_k,
keep_top_k)
return nms_func(bboxes, scores, max_output_boxes_per_class, iou_threshold,
score_threshold, pre_top_k, keep_top_k)

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bbox_nms import efficient_nms
__all__ = ['efficient_nms']

View File

@ -0,0 +1,110 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdeploy.core import mark
from torch import Tensor
def _efficient_nms(
boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
iou_threshold: float = 0.5,
score_threshold: float = 0.05,
pre_top_k: int = -1,
keep_top_k: int = 100,
box_coding: int = 0,
):
"""Wrapper for `efficient_nms` with TensorRT.
Args:
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
scores (Tensor): The detection scores of shape
[N, num_boxes, num_classes].
max_output_boxes_per_class (int): Maximum number of output
boxes per class of nms. Defaults to 1000.
iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
score_threshold (float): score threshold of nms.
Defaults to 0.05.
pre_top_k (int): Number of top K boxes to keep before nms.
Defaults to -1.
keep_top_k (int): Number of top K boxes to keep after nms.
Defaults to -1.
box_coding (int): Bounding boxes format for nms.
Defaults to 0 means [x, y, w, h].
Set to 1 means [x1, y1 ,x2, y2].
Returns:
tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
and `labels` of shape [N, num_det].
"""
boxes = boxes if boxes.dim() == 4 else boxes.unsqueeze(2)
_, det_boxes, det_scores, labels = TRTEfficientNMSop.apply(
boxes, scores, -1, box_coding, iou_threshold, keep_top_k, '1', 0,
score_threshold)
dets = torch.cat([det_boxes, det_scores.unsqueeze(2)], -1)
# retain shape info
batch_size = boxes.size(0)
dets_shape = dets.shape
label_shape = labels.shape
dets = dets.reshape([batch_size, *dets_shape[1:]])
labels = labels.reshape([batch_size, *label_shape[1:]])
return dets, labels
@mark('efficient_nms', inputs=['boxes', 'scores'], outputs=['dets', 'labels'])
def efficient_nms(*args, **kwargs):
"""Wrapper function for `_efficient_nms`."""
return _efficient_nms(*args, **kwargs)
class TRTEfficientNMSop(torch.autograd.Function):
@staticmethod
def forward(
ctx,
boxes,
scores,
background_class=-1,
box_coding=0,
iou_threshold=0.45,
max_output_boxes=100,
plugin_version='1',
score_activation=0,
score_threshold=0.25,
):
batch_size, num_boxes, num_classes = scores.shape
num_det = torch.randint(
0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
det_boxes = torch.randn(batch_size, max_output_boxes, 4)
det_scores = torch.randn(batch_size, max_output_boxes)
det_classes = torch.randint(
0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
return num_det, det_boxes, det_scores, det_classes
@staticmethod
def symbolic(g,
boxes,
scores,
background_class=-1,
box_coding=0,
iou_threshold=0.45,
max_output_boxes=100,
plugin_version='1',
score_activation=0,
score_threshold=0.25):
out = g.op(
'TRT::EfficientNMS_TRT',
boxes,
scores,
background_class_i=background_class,
box_coding_i=box_coding,
iou_threshold_f=iou_threshold,
max_output_boxes_i=max_output_boxes,
plugin_version_s=plugin_version,
score_activation_i=score_activation,
score_threshold_f=score_threshold,
outputs=4)
nums, boxes, scores, classes = out
return nums, boxes, scores, classes