mirror of https://github.com/open-mmlab/mmyolo.git
[Feature] Support YOLOv5 YOLOv6 YOLOX Deploy in mmdeploy (#199)
* Support YOLOv5 YOLOv6 YOLOX Deploy in mmdeploy * Fix lint * Rename _class to detector_type * Add some common * fix lint Co-authored-by: huanghaian <huanghaian@sensetime.com>pull/249/head
parent
190ee5aaa7
commit
275beec782
|
@ -6,7 +6,8 @@ backend_config = dict(
|
|||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 320, 320],
|
||||
min_shape=[1, 3, 192, 192],
|
||||
opt_shape=[1, 3, 640, 640],
|
||||
max_shape=[1, 3, 640, 640])))
|
||||
max_shape=[1, 3, 960, 960])))
|
||||
])
|
||||
use_efficientnms = False # whether to replace TRTBatchedNMS plugin with EfficientNMS plugin # noqa E501
|
|
@ -11,3 +11,4 @@ backend_config = dict(
|
|||
opt_shape=[1, 3, 640, 640],
|
||||
max_shape=[1, 3, 640, 640])))
|
||||
])
|
||||
use_efficientnms = False # whether to replace TRTBatchedNMS plugin with EfficientNMS plugin # noqa E501
|
||||
|
|
|
@ -7,8 +7,9 @@ backend_config = dict(
|
|||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 320, 320],
|
||||
min_shape=[1, 3, 192, 192],
|
||||
opt_shape=[1, 3, 640, 640],
|
||||
max_shape=[1, 3, 640, 640])))
|
||||
max_shape=[1, 3, 960, 960])))
|
||||
],
|
||||
calib_config=dict(create_calib=True, calib_file='calib_data.h5'))
|
||||
use_efficientnms = False # whether to replace TRTBatchedNMS plugin with EfficientNMS plugin # noqa E501
|
|
@ -13,3 +13,4 @@ backend_config = dict(
|
|||
max_shape=[1, 3, 640, 640])))
|
||||
],
|
||||
calib_config=dict(create_calib=True, calib_file='calib_data.h5'))
|
||||
use_efficientnms = False # whether to replace TRTBatchedNMS plugin with EfficientNMS plugin # noqa E501
|
||||
|
|
|
@ -6,7 +6,8 @@ backend_config = dict(
|
|||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 320, 320],
|
||||
min_shape=[1, 3, 192, 192],
|
||||
opt_shape=[1, 3, 640, 640],
|
||||
max_shape=[1, 3, 640, 640])))
|
||||
max_shape=[1, 3, 960, 960])))
|
||||
])
|
||||
use_efficientnms = False # whether to replace TRTBatchedNMS plugin with EfficientNMS plugin # noqa E501
|
|
@ -11,3 +11,4 @@ backend_config = dict(
|
|||
opt_shape=[1, 3, 640, 640],
|
||||
max_shape=[1, 3, 640, 640])))
|
||||
])
|
||||
use_efficientnms = False # whether to replace TRTBatchedNMS plugin with EfficientNMS plugin # noqa E501
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from functools import partial
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
@ -10,6 +11,28 @@ from mmengine.config import ConfigDict
|
|||
from mmengine.structures import InstanceData
|
||||
from torch import Tensor
|
||||
|
||||
from mmyolo.deploy.models.layers import efficient_nms
|
||||
from mmyolo.models.dense_heads import YOLOv5Head
|
||||
|
||||
|
||||
def yolov5_bbox_decoder(priors, bbox_preds, stride):
|
||||
bbox_preds = bbox_preds.sigmoid()
|
||||
|
||||
x_center = (priors[..., 0] + priors[..., 2]) * 0.5
|
||||
y_center = (priors[..., 1] + priors[..., 3]) * 0.5
|
||||
w = priors[..., 2] - priors[..., 0]
|
||||
h = priors[..., 3] - priors[..., 1]
|
||||
|
||||
x_center_pred = (bbox_preds[..., 0] - 0.5) * 2 * stride + x_center
|
||||
y_center_pred = (bbox_preds[..., 1] - 0.5) * 2 * stride + y_center
|
||||
w_pred = (bbox_preds[..., 2] * 2)**2 * w
|
||||
h_pred = (bbox_preds[..., 3] * 2)**2 * h
|
||||
|
||||
decoded_bboxes = torch.stack(
|
||||
[x_center_pred, y_center_pred, w_pred, h_pred], dim=-1)
|
||||
|
||||
return decoded_bboxes
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmyolo.models.dense_heads.yolov5_head.'
|
||||
|
@ -18,7 +41,7 @@ def yolov5_head__predict_by_feat(ctx,
|
|||
self,
|
||||
cls_scores: List[Tensor],
|
||||
bbox_preds: List[Tensor],
|
||||
objectnesses: Optional[List[Tensor]],
|
||||
objectnesses: Optional[List[Tensor]] = None,
|
||||
batch_img_metas: Optional[List[dict]] = None,
|
||||
cfg: Optional[ConfigDict] = None,
|
||||
rescale: bool = False,
|
||||
|
@ -51,6 +74,20 @@ def yolov5_head__predict_by_feat(ctx,
|
|||
tensor in the tuple is (N, num_box), and each element
|
||||
represents the class label of the corresponding box.
|
||||
"""
|
||||
detector_type = type(self)
|
||||
deploy_cfg = ctx.cfg
|
||||
use_efficientnms = deploy_cfg.get('use_efficientnms', False)
|
||||
dtype = cls_scores[0].dtype
|
||||
device = cls_scores[0].device
|
||||
bbox_decoder = self.bbox_coder.decode
|
||||
nms_func = multiclass_nms
|
||||
if use_efficientnms:
|
||||
if detector_type is YOLOv5Head:
|
||||
nms_func = partial(efficient_nms, box_coding=0)
|
||||
bbox_decoder = yolov5_bbox_decoder
|
||||
else:
|
||||
nms_func = efficient_nms
|
||||
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
cfg = self.test_cfg if cfg is None else cfg
|
||||
cfg = copy.deepcopy(cfg)
|
||||
|
@ -59,7 +96,8 @@ def yolov5_head__predict_by_feat(ctx,
|
|||
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
|
||||
|
||||
mlvl_priors = self.prior_generator.grid_priors(
|
||||
featmap_sizes, dtype=cls_scores[0].dtype, device=cls_scores[0].device)
|
||||
featmap_sizes, dtype=dtype, device=device)
|
||||
|
||||
flatten_priors = torch.cat(mlvl_priors)
|
||||
|
||||
mlvl_strides = [
|
||||
|
@ -69,33 +107,36 @@ def yolov5_head__predict_by_feat(ctx,
|
|||
for featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
|
||||
]
|
||||
flatten_stride = torch.cat(mlvl_strides)
|
||||
|
||||
# flatten cls_scores, bbox_preds and objectness
|
||||
flatten_cls_scores = [
|
||||
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes)
|
||||
for cls_score in cls_scores
|
||||
]
|
||||
cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
|
||||
|
||||
flatten_bbox_preds = [
|
||||
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
|
||||
for bbox_pred in bbox_preds
|
||||
]
|
||||
|
||||
flatten_objectness = [
|
||||
objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
|
||||
for objectness in objectnesses
|
||||
]
|
||||
|
||||
cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
|
||||
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
|
||||
flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
|
||||
bboxes = self.bbox_coder.decode(flatten_priors[None], flatten_bbox_preds,
|
||||
flatten_stride)
|
||||
|
||||
# directly multiply score factor and feed to nms
|
||||
scores = cls_scores * (flatten_objectness.unsqueeze(-1))
|
||||
if objectnesses is not None:
|
||||
flatten_objectness = [
|
||||
objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
|
||||
for objectness in objectnesses
|
||||
]
|
||||
flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
|
||||
cls_scores = cls_scores * (flatten_objectness.unsqueeze(-1))
|
||||
|
||||
scores = cls_scores
|
||||
|
||||
bboxes = bbox_decoder(flatten_priors[None], flatten_bbox_preds,
|
||||
flatten_stride)
|
||||
|
||||
if not with_nms:
|
||||
return bboxes, scores
|
||||
deploy_cfg = ctx.cfg
|
||||
|
||||
post_params = get_post_processing_params(deploy_cfg)
|
||||
max_output_boxes_per_class = post_params.max_output_boxes_per_class
|
||||
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
|
||||
|
@ -103,6 +144,5 @@ def yolov5_head__predict_by_feat(ctx,
|
|||
pre_top_k = post_params.pre_top_k
|
||||
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
|
||||
|
||||
return multiclass_nms(bboxes, scores, max_output_boxes_per_class,
|
||||
iou_threshold, score_threshold, pre_top_k,
|
||||
keep_top_k)
|
||||
return nms_func(bboxes, scores, max_output_boxes_per_class, iou_threshold,
|
||||
score_threshold, pre_top_k, keep_top_k)
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .bbox_nms import efficient_nms
|
||||
|
||||
__all__ = ['efficient_nms']
|
|
@ -0,0 +1,110 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmdeploy.core import mark
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def _efficient_nms(
|
||||
boxes: Tensor,
|
||||
scores: Tensor,
|
||||
max_output_boxes_per_class: int = 1000,
|
||||
iou_threshold: float = 0.5,
|
||||
score_threshold: float = 0.05,
|
||||
pre_top_k: int = -1,
|
||||
keep_top_k: int = 100,
|
||||
box_coding: int = 0,
|
||||
):
|
||||
"""Wrapper for `efficient_nms` with TensorRT.
|
||||
|
||||
Args:
|
||||
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
|
||||
scores (Tensor): The detection scores of shape
|
||||
[N, num_boxes, num_classes].
|
||||
max_output_boxes_per_class (int): Maximum number of output
|
||||
boxes per class of nms. Defaults to 1000.
|
||||
iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
|
||||
score_threshold (float): score threshold of nms.
|
||||
Defaults to 0.05.
|
||||
pre_top_k (int): Number of top K boxes to keep before nms.
|
||||
Defaults to -1.
|
||||
keep_top_k (int): Number of top K boxes to keep after nms.
|
||||
Defaults to -1.
|
||||
box_coding (int): Bounding boxes format for nms.
|
||||
Defaults to 0 means [x, y, w, h].
|
||||
Set to 1 means [x1, y1 ,x2, y2].
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
|
||||
and `labels` of shape [N, num_det].
|
||||
"""
|
||||
boxes = boxes if boxes.dim() == 4 else boxes.unsqueeze(2)
|
||||
_, det_boxes, det_scores, labels = TRTEfficientNMSop.apply(
|
||||
boxes, scores, -1, box_coding, iou_threshold, keep_top_k, '1', 0,
|
||||
score_threshold)
|
||||
dets = torch.cat([det_boxes, det_scores.unsqueeze(2)], -1)
|
||||
|
||||
# retain shape info
|
||||
batch_size = boxes.size(0)
|
||||
|
||||
dets_shape = dets.shape
|
||||
label_shape = labels.shape
|
||||
dets = dets.reshape([batch_size, *dets_shape[1:]])
|
||||
labels = labels.reshape([batch_size, *label_shape[1:]])
|
||||
return dets, labels
|
||||
|
||||
|
||||
@mark('efficient_nms', inputs=['boxes', 'scores'], outputs=['dets', 'labels'])
|
||||
def efficient_nms(*args, **kwargs):
|
||||
"""Wrapper function for `_efficient_nms`."""
|
||||
return _efficient_nms(*args, **kwargs)
|
||||
|
||||
|
||||
class TRTEfficientNMSop(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
boxes,
|
||||
scores,
|
||||
background_class=-1,
|
||||
box_coding=0,
|
||||
iou_threshold=0.45,
|
||||
max_output_boxes=100,
|
||||
plugin_version='1',
|
||||
score_activation=0,
|
||||
score_threshold=0.25,
|
||||
):
|
||||
batch_size, num_boxes, num_classes = scores.shape
|
||||
num_det = torch.randint(
|
||||
0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
|
||||
det_boxes = torch.randn(batch_size, max_output_boxes, 4)
|
||||
det_scores = torch.randn(batch_size, max_output_boxes)
|
||||
det_classes = torch.randint(
|
||||
0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
|
||||
return num_det, det_boxes, det_scores, det_classes
|
||||
|
||||
@staticmethod
|
||||
def symbolic(g,
|
||||
boxes,
|
||||
scores,
|
||||
background_class=-1,
|
||||
box_coding=0,
|
||||
iou_threshold=0.45,
|
||||
max_output_boxes=100,
|
||||
plugin_version='1',
|
||||
score_activation=0,
|
||||
score_threshold=0.25):
|
||||
out = g.op(
|
||||
'TRT::EfficientNMS_TRT',
|
||||
boxes,
|
||||
scores,
|
||||
background_class_i=background_class,
|
||||
box_coding_i=box_coding,
|
||||
iou_threshold_f=iou_threshold,
|
||||
max_output_boxes_i=max_output_boxes,
|
||||
plugin_version_s=plugin_version,
|
||||
score_activation_i=score_activation,
|
||||
score_threshold_f=score_threshold,
|
||||
outputs=4)
|
||||
nums, boxes, scores, classes = out
|
||||
return nums, boxes, scores, classes
|
Loading…
Reference in New Issue