Support exporting to ONNX for Faster R-CNN (#20)
* update rewriter for rpn head
* new NMS implementation
* update bbox_head to support Faster R-CNN with TensorRT
* rewrite SingleRoIExtractor
* remove unnecessary import
* fix fcos and bbox_nms
* resolve comments

Co-authored-by: grimoire <yaoqian@sensetime.com>
parent b9e64f9e1c
commit 10a2385d01
@@ -114,10 +114,7 @@ def init_backend_model(model_files: Sequence[str],
     raise NotImplementedError(f'Unknown codebase type: {codebase}')
 
 
-def get_classes_from_config(
-    codebase: str,
-    model_cfg: Union[str, mmcv.Config],
-):
+def get_classes_from_config(codebase: str, model_cfg: Union[str, mmcv.Config]):
     model_cfg_str = model_cfg
     if codebase == 'mmdet':
         if module_exist(codebase):
@@ -15,7 +15,9 @@ class DummyONNXNMSop(torch.autograd.Function):
                 score_threshold):
         batch_size, num_class, num_box = scores.shape
         # create dummy indices of nms output
-        num_fake_det = 2
+        # number of detection should be large enough to
+        # cover all layers of fpn
+        num_fake_det = 100
         batch_inds = torch.randint(batch_size, (num_fake_det, 1))
         cls_inds = torch.randint(num_class, (num_fake_det, 1))
         box_inds = torch.randint(num_box, (num_fake_det, 1))
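Note: for context, a minimal sketch of the dummy-op pattern this class relies on, assuming the usual torch.autograd.Function with a symbolic export hook. The trace-time forward only returns fake selection indices of a fixed size so tracing can proceed, while the symbolic method maps the call onto onnx::NonMaxSuppression. Names here (FakeONNXNMS) are illustrative, not the exact mmdeploy implementation.

import torch


class FakeONNXNMS(torch.autograd.Function):
    """Illustrative stand-in for a dummy ONNX NMS op."""

    @staticmethod
    def forward(ctx, boxes, scores, max_output_boxes_per_class,
                iou_threshold, score_threshold):
        batch_size, num_class, num_box = scores.shape
        # large enough to cover candidates from every FPN level
        num_fake_det = 100
        batch_inds = torch.randint(batch_size, (num_fake_det, 1))
        cls_inds = torch.randint(num_class, (num_fake_det, 1))
        box_inds = torch.randint(num_box, (num_fake_det, 1))
        return torch.cat([batch_inds, cls_inds, box_inds], dim=1)

    @staticmethod
    def symbolic(g, boxes, scores, max_output_boxes_per_class,
                 iou_threshold, score_threshold):
        # only used during torch.onnx.export; emits a real NMS node
        return g.op('NonMaxSuppression', boxes, scores,
                    max_output_boxes_per_class, iou_threshold,
                    score_threshold)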
@@ -47,8 +47,7 @@ def _multiclass_nms(boxes,
                     iou_threshold=0.5,
                     score_threshold=0.05,
                     pre_top_k=-1,
-                    keep_top_k=-1,
-                    labels=None):
+                    keep_top_k=-1):
     """Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.
 
     This function helps exporting to onnx with batch and multiclass NMS op.
@@ -69,9 +68,6 @@ def _multiclass_nms(boxes,
             Defaults to -1.
         keep_top_k (int): Number of top K boxes to keep after nms.
             Defaults to -1.
-        labels (Tensor, optional): It not None, explicit labels would be used.
-            Otherwise, labels would be automatically generated using
-            num_classed. Defaults to None.
 
     Returns:
         tuple[Tensor, Tensor]: dets of shape [N, num_det, 5] and class labels
@@ -81,28 +77,20 @@ def _multiclass_nms(boxes,
     iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
     score_threshold = torch.tensor([score_threshold], dtype=torch.float32)
     batch_size = scores.shape[0]
-    num_class = scores.shape[2]
 
     if pre_top_k > 0:
         max_scores, _ = scores.max(-1)
         _, topk_inds = max_scores.topk(pre_top_k)
         batch_inds = torch.arange(batch_size).view(
             -1, 1).expand_as(topk_inds).long()
         # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
         boxes = boxes[batch_inds, topk_inds, :]
         scores = scores[batch_inds, topk_inds, :]
-        if labels is not None:
-            labels = labels[batch_inds, topk_inds]
 
     scores = scores.permute(0, 2, 1)
     selected_indices = DummyONNXNMSop.apply(boxes, scores,
                                             max_output_boxes_per_class,
                                             iou_threshold, score_threshold)
 
-    if labels is None:
-        labels = torch.arange(num_class, dtype=torch.long).to(scores.device)
-    labels = labels.view(1, num_class, 1).expand_as(scores)
 
     dets, labels = select_nms_index(
         scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k)
     return dets, labels
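Note: the pre_top_k branch above keeps only the highest-scoring candidates per image before NMS, using advanced indexing with a broadcast batch index. A small, self-contained check of that gather in plain PyTorch (independent of the export path, tensor sizes chosen arbitrarily):

import torch

batch_size, num_box, num_class = 2, 50, 4
boxes = torch.rand(batch_size, num_box, 4)
scores = torch.rand(batch_size, num_box, num_class)
pre_top_k = 10

max_scores, _ = scores.max(-1)
_, topk_inds = max_scores.topk(pre_top_k)
batch_inds = torch.arange(batch_size).view(-1, 1).expand_as(topk_inds).long()

# gather the same top-k candidates for every image in the batch
topk_boxes = boxes[batch_inds, topk_inds, :]
topk_scores = scores[batch_inds, topk_inds, :]

# equivalent to looping over the batch one image at a time
for b in range(batch_size):
    assert torch.equal(topk_boxes[b], boxes[b, topk_inds[b]])
    assert torch.equal(topk_scores[b], scores[b, topk_inds[b]])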
@@ -118,8 +106,7 @@ def multiclass_nms_static(ctx,
                           iou_threshold=0.5,
                           score_threshold=0.05,
                           pre_top_k=-1,
-                          keep_top_k=-1,
-                          labels=None):
+                          keep_top_k=-1):
     boxes = boxes if boxes.dim() == 4 else boxes.unsqueeze(2)
     keep_top_k = max_output_boxes_per_class if keep_top_k < 0 else min(
         max_output_boxes_per_class, keep_top_k)
@@ -1,4 +1,8 @@
 from .anchor_head import get_bboxes_of_anchor_head
 from .fcos_head import get_bboxes_of_fcos_head
+from .rpn_head import get_bboxes_of_rpn_head
 
-__all__ = ['get_bboxes_of_anchor_head', 'get_bboxes_of_fcos_head']
+__all__ = [
+    'get_bboxes_of_anchor_head', 'get_bboxes_of_fcos_head',
+    'get_bboxes_of_rpn_head'
+]
@@ -7,13 +7,11 @@ from mmdeploy.utils import is_dynamic_shape
 
 @FUNCTION_REWRITER.register_rewriter(
     func_name='mmdet.models.AnchorHead.get_bboxes')
-@FUNCTION_REWRITER.register_rewriter(
-    func_name='mmdet.models.RetinaHead.get_bboxes')
 def get_bboxes_of_anchor_head(ctx,
                               self,
                               cls_scores,
                               bbox_preds,
-                              img_shape,
+                              img_metas,
                               with_nms=True,
                               cfg=None,
                               **kwargs):
@@ -36,7 +34,8 @@ def get_bboxes_of_anchor_head(ctx,
     pre_topk = cfg.get('nms_pre', -1)
 
     # loop over features, decode boxes
-    mlvl_bboxes = []
+    mlvl_valid_bboxes = []
+    mlvl_valid_anchors = []
     mlvl_scores = []
     for level_id, cls_score, bbox_pred, anchors in zip(
             range(num_levels), mlvl_cls_scores, mlvl_bbox_preds, mlvl_anchors):
@@ -74,20 +73,23 @@ def get_bboxes_of_anchor_head(ctx,
             # BG cat_id: num_class
             max_scores, _ = scores[..., :-1].max(-1)
             _, topk_inds = max_scores.topk(pre_topk)
-            batch_inds = torch.arange(batch_size).view(-1,
-                                                       1).expand_as(topk_inds)
+            batch_inds = torch.arange(
+                batch_size, device=device).view(-1, 1).expand_as(topk_inds)
             anchors = anchors[batch_inds, topk_inds, :]
             bbox_pred = bbox_pred[batch_inds, topk_inds, :]
             scores = scores[batch_inds, topk_inds, :]
 
-        bboxes = self.bbox_coder.decode(
-            anchors, bbox_pred, max_shape=img_shape)
-        mlvl_bboxes.append(bboxes)
+        mlvl_valid_bboxes.append(bbox_pred)
         mlvl_scores.append(scores)
+        mlvl_valid_anchors.append(anchors)
 
-    batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
+    batch_mlvl_valid_bboxes = torch.cat(mlvl_valid_bboxes, dim=1)
     batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
 
+    batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1)
+    batch_mlvl_bboxes = self.bbox_coder.decode(
+        batch_mlvl_anchors,
+        batch_mlvl_valid_bboxes,
+        max_shape=img_metas['img_shape'])
     # ignore background class
     if not self.use_sigmoid_cls:
         batch_mlvl_scores = batch_mlvl_scores[..., :self.num_classes]
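Note: the refactor above defers bbox_coder.decode from the per-level loop to a single call on the concatenated tensors, which keeps the exported graph smaller. Decoding is applied box by box, so the two orderings agree; a toy check with a stand-in decoder (toy_decode is illustrative only, not mmdet's DeltaXYWHBBoxCoder):

import torch


def toy_decode(anchors, deltas, max_shape):
    # simple additive decode clamped to the image; stands in for a real coder
    boxes = anchors + deltas
    boxes[..., 0::2] = boxes[..., 0::2].clamp(0, max_shape[1])
    boxes[..., 1::2] = boxes[..., 1::2].clamp(0, max_shape[0])
    return boxes


anchors = [torch.rand(1, n, 4) * 100 for n in (8, 4)]   # two FPN levels
deltas = [torch.rand(1, n, 4) for n in (8, 4)]

per_level = torch.cat(
    [toy_decode(a, d, (64, 64)) for a, d in zip(anchors, deltas)], dim=1)
batched = toy_decode(
    torch.cat(anchors, dim=1), torch.cat(deltas, dim=1), (64, 64))
assert torch.allclose(per_level, batched)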
@@ -12,7 +12,7 @@ def get_bboxes_of_fcos_head(ctx,
                             cls_scores,
                             bbox_preds,
                             centernesses,
-                            img_shape,
+                            img_metas,
                             with_nms=True,
                             cfg=None,
                             **kwargs):
@@ -87,7 +87,7 @@ def get_bboxes_of_fcos_head(ctx,
     batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
     batch_mlvl_centerness = torch.cat(mlvl_centerness, dim=1)
     batch_mlvl_bboxes = distance2bbox(
-        batch_mlvl_points, batch_mlvl_bboxes, max_shape=img_shape)
+        batch_mlvl_points, batch_mlvl_bboxes, max_shape=img_metas['img_shape'])
 
     if not with_nms:
         return batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_centerness
@@ -0,0 +1,112 @@
+import torch
+
+from mmdeploy.core import FUNCTION_REWRITER
+from mmdeploy.mmdet.core import multiclass_nms
+from mmdeploy.utils import is_dynamic_shape
+
+
+@FUNCTION_REWRITER.register_rewriter('mmdet.models.RPNHead.get_bboxes')
+def get_bboxes_of_rpn_head(ctx,
+                           self,
+                           cls_scores,
+                           bbox_preds,
+                           img_metas,
+                           with_nms=True,
+                           cfg=None,
+                           **kwargs):
+    assert len(cls_scores) == len(bbox_preds)
+    deploy_cfg = ctx.cfg
+    is_dynamic_flag = is_dynamic_shape(deploy_cfg)
+    num_levels = len(cls_scores)
+
+    device = cls_scores[0].device
+    featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
+    mlvl_anchors = self.anchor_generator.grid_anchors(
+        featmap_sizes, device=device)
+
+    mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
+    mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
+    assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) == len(mlvl_anchors)
+
+    cfg = self.test_cfg if cfg is None else cfg
+    batch_size = mlvl_cls_scores[0].shape[0]
+    pre_topk = cfg.get('nms_pre', -1)
+
+    # loop over features, decode boxes
+    mlvl_valid_bboxes = []
+    mlvl_scores = []
+    mlvl_valid_anchors = []
+    for level_id, cls_score, bbox_pred, anchors in zip(
+            range(num_levels), mlvl_cls_scores, mlvl_bbox_preds, mlvl_anchors):
+        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+        cls_score = cls_score.permute(0, 2, 3, 1)
+        if self.use_sigmoid_cls:
+            cls_score = cls_score.reshape(batch_size, -1)
+            scores = cls_score.sigmoid()
+        else:
+            cls_score = cls_score.reshape(batch_size, -1, 2)
+            # We set FG labels to [0, num_class-1] and BG label to
+            # num_class in RPN head since mmdet v2.5, which is unified to
+            # be consistent with other head since mmdet v2.0. In mmdet v2.0
+            # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
+            scores = cls_score.softmax(-1)[..., 0]
+        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
+
+        # use static anchor if input shape is static
+        if not is_dynamic_flag:
+            anchors = anchors.data
+
+        anchors = anchors.expand_as(bbox_pred)
+
+        enable_nms_pre = True
+        backend = deploy_cfg['backend']
+        # topk in tensorrt does not support shape<k
+        # final level might meet the problem
+        # TODO: support dynamic shape feature with TensorRT for topK op
+        if backend == 'tensorrt':
+            enable_nms_pre = (level_id != num_levels - 1)
+
+        if pre_topk > 0 and enable_nms_pre:
+            _, topk_inds = scores.topk(pre_topk)
+            batch_inds = torch.arange(
+                batch_size, device=device).view(-1, 1).expand_as(topk_inds)
+            # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
+            transformed_inds = scores.shape[1] * batch_inds + topk_inds
+            scores = scores.reshape(-1, 1)[transformed_inds].reshape(
+                batch_size, -1)
+            bbox_pred = bbox_pred.reshape(-1, 4)[transformed_inds, :].reshape(
+                batch_size, -1, 4)
+            anchors = anchors.reshape(-1, 4)[transformed_inds, :].reshape(
+                batch_size, -1, 4)
+        mlvl_valid_bboxes.append(bbox_pred)
+        mlvl_scores.append(scores)
+        mlvl_valid_anchors.append(anchors)
+
+    batch_mlvl_bboxes = torch.cat(mlvl_valid_bboxes, dim=1)
+    batch_mlvl_scores = torch.cat(mlvl_scores, dim=1).unsqueeze(2)
+    batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1)
+    batch_mlvl_bboxes = self.bbox_coder.decode(
+        batch_mlvl_anchors,
+        batch_mlvl_bboxes,
+        max_shape=img_metas['img_shape'])
+    # ignore background class
+    if not self.use_sigmoid_cls:
+        batch_mlvl_scores = batch_mlvl_scores[..., :self.num_classes]
+    if not with_nms:
+        return batch_mlvl_bboxes, batch_mlvl_scores
+
+    post_params = deploy_cfg.post_processing
+    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
+    score_threshold = cfg.get('score_thr', post_params.score_threshold)
+    pre_top_k = post_params.pre_top_k
+    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
+    # only one class in rpn
+    max_output_boxes_per_class = keep_top_k
+    return multiclass_nms(
+        batch_mlvl_bboxes,
+        batch_mlvl_scores,
+        max_output_boxes_per_class,
+        iou_threshold=iou_threshold,
+        score_threshold=score_threshold,
+        pre_top_k=pre_top_k,
+        keep_top_k=keep_top_k)
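Note: the rewritten RPN head avoids two-dimensional advanced indexing (which trips the onnx2tensorrt issue linked in the code) by flattening to a single index: for scores of shape (batch, num_anchors), `scores.shape[1] * batch_inds + topk_inds` addresses the same elements in the flattened tensor. A plain-PyTorch check of that equivalence (sizes arbitrary):

import torch

batch_size, num_anchors = 2, 100
scores = torch.rand(batch_size, num_anchors)
bbox_pred = torch.rand(batch_size, num_anchors, 4)
pre_topk = 10

_, topk_inds = scores.topk(pre_topk)
batch_inds = torch.arange(batch_size).view(-1, 1).expand_as(topk_inds)

# single flattened index instead of a (batch_inds, topk_inds) pair
transformed_inds = scores.shape[1] * batch_inds + topk_inds
scores_flat = scores.reshape(-1, 1)[transformed_inds].reshape(batch_size, -1)
bbox_flat = bbox_pred.reshape(-1, 4)[transformed_inds, :].reshape(
    batch_size, -1, 4)

# same result as the direct two-index gather
assert torch.equal(scores_flat, scores[batch_inds, topk_inds])
assert torch.equal(bbox_flat, bbox_pred[batch_inds, topk_inds, :])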
@@ -1,4 +1,9 @@
-from .single_stage import forward_of_single_stage
+from .base import forward_of_base_detector
+from .rpn import simple_test_of_rpn
+from .single_stage import simple_test_of_single_stage
 from .two_stage import extract_feat_of_two_stage
 
-__all__ = ['forward_of_single_stage', 'extract_feat_of_two_stage']
+__all__ = [
+    'simple_test_of_single_stage', 'extract_feat_of_two_stage',
+    'forward_of_base_detector', 'simple_test_of_rpn'
+]
@@ -0,0 +1,22 @@
+import torch
+
+from mmdeploy.core import FUNCTION_REWRITER
+from mmdeploy.utils import is_dynamic_shape
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='mmdet.models.BaseDetector.forward')
+def forward_of_base_detector(ctx, self, img, img_metas=None, **kwargs):
+    if img_metas is None:
+        img_metas = {}
+    assert isinstance(img_metas, dict)
+    assert isinstance(img, torch.Tensor)
+
+    deploy_cfg = ctx.cfg
+    is_dynamic_flag = is_dynamic_shape(deploy_cfg)
+    # get origin input shape as tensor to support onnx dynamic shape
+    img_shape = torch._shape_as_tensor(img)[2:]
+    if not is_dynamic_flag:
+        img_shape = [int(val) for val in img_shape]
+    img_metas['img_shape'] = img_shape
+    return self.simple_test(img, img_metas, **kwargs)
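Note: the base-detector rewrite records the input shape either as a tensor (so it stays symbolic in the ONNX graph when the deploy config requests dynamic shapes) or as plain Python ints (which get baked into the exported graph as constants). A minimal sketch of the distinction, outside mmdeploy; torch._shape_as_tensor is the internal helper the rewrite itself uses:

import torch

img = torch.rand(1, 3, 224, 320)

# dynamic: a tensor that follows the real input shape at runtime
dynamic_shape = torch._shape_as_tensor(img)[2:]     # tensor([224, 320])

# static: plain ints, traced as constants in the exported graph
static_shape = [int(v) for v in torch._shape_as_tensor(img)[2:]]   # [224, 320]

print(dynamic_shape, static_shape)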
@@ -0,0 +1,7 @@
+from mmdeploy.core import FUNCTION_REWRITER
+
+
+@FUNCTION_REWRITER.register_rewriter(func_name='mmdet.models.RPN.simple_test')
+def simple_test_of_rpn(ctx, self, img, img_metas, **kwargs):
+    x = self.extract_feat(img)
+    return self.rpn_head.simple_test_rpn(x, img_metas)
@@ -1,26 +1,8 @@
-import torch
-
 from mmdeploy.core import FUNCTION_REWRITER
-from mmdeploy.utils import is_dynamic_shape
-
-
-@FUNCTION_REWRITER.register_rewriter(
-    'mmdet.models.SingleStageDetector.extract_feat')
-def extract_feat_of_single_stage(ctx, self, img):
-    return ctx.origin_func(self, img)
 
 
-@FUNCTION_REWRITER.register_rewriter(func_name='mmdet.models.RetinaNet.forward'
-                                     )
 @FUNCTION_REWRITER.register_rewriter(
-    func_name='mmdet.models.SingleStageDetector.forward')
-def forward_of_single_stage(ctx, self, data, **kwargs):
-    deploy_cfg = ctx.cfg
-    is_dynamic_flag = is_dynamic_shape(deploy_cfg)
-    # get origin input shape to support onnx dynamic shape
-    img_shape = torch._shape_as_tensor(data)[2:]
-    if not is_dynamic_flag:
-        img_shape = [int(val) for val in img_shape]
-    x = self.extract_feat(data)
-    outs = self.bbox_head(x)
-    return self.bbox_head.get_bboxes(*outs, img_shape, **kwargs)
+    func_name='mmdet.models.SingleStageDetector.simple_test')
+def simple_test_of_single_stage(ctx, self, img, img_metas, **kwargs):
+    feat = self.extract_feat(img)
+    return self.bbox_head.simple_test(feat, img_metas, **kwargs)
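Note: FUNCTION_REWRITER.register_rewriter swaps the named mmdet method for the decorated function at export time, handing it a ctx object with the deploy config and the original function. A highly simplified, hypothetical sketch of that idea only; mmdeploy's actual registry also handles per-backend rewriters and scoped enabling:

import importlib
from types import SimpleNamespace


def register_rewriter(func_name):
    """Toy decorator: monkey-patch `package.module.Class.method` with the
    decorated rewriter. Illustrative only, not mmdeploy's implementation."""
    module_path, cls_name, method_name = func_name.rsplit('.', 2)

    def decorator(rewriter):
        cls = getattr(importlib.import_module(module_path), cls_name)
        origin_func = getattr(cls, method_name)
        ctx = SimpleNamespace(cfg=None, origin_func=origin_func)

        def patched(self, *args, **kwargs):
            return rewriter(ctx, self, *args, **kwargs)

        setattr(cls, method_name, patched)
        return rewriter

    return decorator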
@@ -6,3 +6,18 @@ from mmdeploy.core import FUNCTION_REWRITER, mark
 @mark('extract_feat', inputs='img', outputs='feat')
 def extract_feat_of_two_stage(ctx, self, img):
     return ctx.origin_func(self, img)
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    'mmdet.models.TwoStageDetector.simple_test')
+def simple_test_of_two_stage(ctx,
+                             self,
+                             img,
+                             img_metas,
+                             proposals=None,
+                             **kwargs):
+    assert self.with_bbox, 'Bbox head must be implemented.'
+    x = self.extract_feat(img)
+    if proposals is None:
+        proposals, _ = self.rpn_head.simple_test_rpn(x, img_metas)
+    return self.roi_head.simple_test(x, proposals, img_metas, rescale=False)
@@ -1 +1,4 @@
+from .bbox_heads import *  # noqa: F401, F403
 from .roi_extractors import *  # noqa: F401, F403
+from .standard_roi_head import *  # noqa: F401, F403
+from .test_mixins import *  # noqa: F401, F403
@@ -0,0 +1,3 @@
+from .bbox_head import get_bboxes_of_bbox_head
+
+__all__ = ['get_bboxes_of_bbox_head']
@@ -0,0 +1,62 @@
+import torch
+import torch.nn.functional as F
+
+from mmdeploy.core import FUNCTION_REWRITER
+from mmdeploy.mmdet.core import multiclass_nms
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='mmdet.models.roi_heads.BBoxHead.get_bboxes')
+def get_bboxes_of_bbox_head(ctx, self, rois, cls_score, bbox_pred, img_shape,
+                            cfg, **kwargs):
+
+    assert rois.ndim == 3, 'Only support export two stage ' \
+                           'model to ONNX ' \
+                           'with batch dimension. '
+    if self.custom_cls_channels:
+        scores = self.loss_cls.get_activation(cls_score)
+    else:
+        scores = F.softmax(
+            cls_score, dim=-1) if cls_score is not None else None
+
+    if bbox_pred is not None:
+        bboxes = self.bbox_coder.decode(
+            rois[..., 1:], bbox_pred, max_shape=img_shape)
+    else:
+        bboxes = rois[..., 1:].clone()
+        if img_shape is not None:
+            max_shape = bboxes.new_tensor(img_shape)[..., :2]
+            min_xy = bboxes.new_tensor(0)
+            max_xy = torch.cat([max_shape] * 2, dim=-1).flip(-1).unsqueeze(-2)
+            bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
+            bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
+
+    batch_size = scores.shape[0]
+    device = scores.device
+    # ignore background class
+    scores = scores[..., :self.num_classes]
+    if not self.reg_class_agnostic:
+        # only keep boxes with the max scores
+        max_inds = scores.reshape(-1, self.num_classes).argmax(1, keepdim=True)
+        bboxes = bboxes.reshape(-1, self.num_classes, 4)
+        dim0_inds = torch.arange(
+            bboxes.shape[0], device=device).view(-1, 1).expand_as(max_inds)
+        bboxes = bboxes[dim0_inds, max_inds].reshape(batch_size, -1, 4)
+
+    # get nms params
+    post_params = ctx.cfg.post_processing
+    max_output_boxes_per_class = post_params.max_output_boxes_per_class
+    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
+    score_threshold = cfg.get('score_thr', post_params.score_threshold)
+    pre_top_k = post_params.pre_top_k
+    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
+    dets, labels = multiclass_nms(
+        bboxes,
+        scores,
+        max_output_boxes_per_class,
+        iou_threshold=iou_threshold,
+        score_threshold=score_threshold,
+        pre_top_k=pre_top_k,
+        keep_top_k=keep_top_k)
+
+    return dets, labels
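Note: when the box head is not class-agnostic it predicts one box per class, and the rewrite keeps only the box of each RoI's highest-scoring class via an argmax gather. A plain-PyTorch check that the vectorized gather matches the obvious loop (sizes arbitrary):

import torch

batch_size, num_rois, num_classes = 2, 6, 3
scores = torch.rand(batch_size, num_rois, num_classes)
bboxes = torch.rand(batch_size, num_rois, num_classes * 4)

max_inds = scores.reshape(-1, num_classes).argmax(1, keepdim=True)
flat_bboxes = bboxes.reshape(-1, num_classes, 4)
dim0_inds = torch.arange(flat_bboxes.shape[0]).view(-1, 1).expand_as(max_inds)
picked = flat_bboxes[dim0_inds, max_inds].reshape(batch_size, -1, 4)

# reference: loop over every roi and take the box of its best class
for b in range(batch_size):
    for r in range(num_rois):
        best = scores[b, r].argmax()
        assert torch.equal(picked[b, r],
                           bboxes[b, r].reshape(num_classes, 4)[best])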
@@ -1 +1,8 @@
-from .base_roi_extractor import *  # noqa: F401, F403
+from .single_level_roi_extractor import (
+    forward_of_single_roi_extractor_dynamic,
+    forward_of_single_roi_extractor_static)
+
+__all__ = [
+    'forward_of_single_roi_extractor_dynamic',
+    'forward_of_single_roi_extractor_static'
+]
@@ -51,11 +51,11 @@ class MultiLevelRoiAlign(Function):
 @FUNCTION_REWRITER.register_rewriter(
     func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward',
     backend='tensorrt')
-def SingleRoIExtractor_forward_static(rewriter,
-                                      self,
-                                      feats,
-                                      rois,
-                                      roi_scale_factor=None):
+def forward_of_single_roi_extractor_static(ctx,
+                                           self,
+                                           feats,
+                                           rois,
+                                           roi_scale_factor=None):
     featmap_strides = self.featmap_strides
     finest_scale = self.finest_scale
 
@@ -70,3 +70,31 @@ def SingleRoIExtractor_forward_static(rewriter,
     return MultiLevelRoiAlign.apply(*feats, rois, out_size, sampling_ratio,
                                     roi_scale_factor, finest_scale,
                                     featmap_strides, aligned)
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward')
+def forward_of_single_roi_extractor_dynamic(ctx,
+                                            self,
+                                            feats,
+                                            rois,
+                                            roi_scale_factor=None):
+    out_size = self.roi_layers[0].output_size
+    num_levels = len(feats)
+    roi_feats = feats[0].new_zeros(rois.shape[0], self.out_channels, *out_size)
+    if num_levels == 1:
+        assert len(rois) > 0, 'The number of rois should be positive'
+        return self.roi_layers[0](feats[0], rois)
+
+    target_lvls = self.map_roi_levels(rois, num_levels)
+
+    if roi_scale_factor is not None:
+        rois = self.roi_rescale(rois, roi_scale_factor)
+
+    for i in range(num_levels):
+        mask = target_lvls == i
+        inds = mask.nonzero(as_tuple=False).squeeze(1)
+        rois_i = rois[inds]
+        roi_feats_t = self.roi_layers[i](feats[i], rois_i)
+        roi_feats[inds] = roi_feats_t
+    return roi_feats
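Note: forward_of_single_roi_extractor_dynamic relies on self.map_roi_levels to assign each RoI to an FPN level by its scale. A stand-alone sketch of that assignment rule, mirroring the usual mmdet heuristic floor(log2(sqrt(w*h) / finest_scale)) clamped to the valid range; treat it as an approximation of the upstream implementation rather than a copy of it:

import torch


def map_roi_levels(rois, num_levels, finest_scale=56):
    # rois: (num_rois, 5) laid out as (batch_ind, x1, y1, x2, y2)
    scale = torch.sqrt(
        (rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2]))
    target_lvls = torch.floor(torch.log2(scale / finest_scale + 1e-6))
    return target_lvls.clamp(min=0, max=num_levels - 1).long()


rois = torch.tensor([[0., 0., 0., 32., 32.],
                     [0., 0., 0., 300., 300.]])
print(map_roi_levels(rois, num_levels=4))  # small box -> level 0, large -> higher level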
@@ -0,0 +1,16 @@
+from mmdeploy.core import FUNCTION_REWRITER
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='mmdet.models.roi_heads.StandardRoIHead.simple_test')
+def simple_test_of_standard_roi_head(ctx, self, x, proposals, img_metas,
+                                     **kwargs):
+    assert self.with_bbox, 'Bbox head must be implemented.'
+    det_bboxes, det_labels = self.simple_test_bboxes(
+        x, img_metas, proposals, self.test_cfg, rescale=False)
+    if not self.with_mask:
+        return det_bboxes, det_labels
+
+    segm_results = self.simple_test_mask(
+        x, img_metas, det_bboxes, det_labels, rescale=False)
+    return det_bboxes, det_labels, segm_results
@@ -0,0 +1,66 @@
+import torch
+
+from mmdeploy.core import FUNCTION_REWRITER
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='mmdet.models.roi_heads.test_mixins.\
+BBoxTestMixin.simple_test_bboxes')
+def simple_test_bboxes_of_bbox_test_mixin(ctx, self, x, img_metas, proposals,
+                                          rcnn_test_cfg, **kwargs):
+    rois = proposals
+    batch_index = torch.arange(
+        rois.shape[0], device=rois.device).float().view(-1, 1, 1).expand(
+            rois.size(0), rois.size(1), 1)
+    rois = torch.cat([batch_index, rois[..., :4]], dim=-1)
+    batch_size = rois.shape[0]
+    num_proposals_per_img = rois.shape[1]
+
+    # Eliminate the batch dimension
+    rois = rois.view(-1, 5)
+    bbox_results = self._bbox_forward(x, rois)
+    cls_score = bbox_results['cls_score']
+    bbox_pred = bbox_results['bbox_pred']
+
+    # Recover the batch dimension
+    rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1))
+    cls_score = cls_score.reshape(batch_size, num_proposals_per_img,
+                                  cls_score.size(-1))
+
+    bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img,
+                                  bbox_pred.size(-1))
+    det_bboxes, det_labels = self.bbox_head.get_bboxes(
+        rois, cls_score, bbox_pred, img_metas['img_shape'], cfg=rcnn_test_cfg)
+    return det_bboxes, det_labels
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='mmdet.models.roi_heads.test_mixins.\
+MaskTestMixin.simple_test_mask')
+def simple_test_mask_of_mask_test_mixin(ctx, self, x, img_metas, det_bboxes,
+                                        det_labels, **kwargs):
+    assert det_bboxes.shape[1] != 0, 'Can not record MaskHead as it \
+has not been executed this time'
+
+    batch_size = det_bboxes.size(0)
+    # if det_bboxes is rescaled to the original image size, we need to
+    # rescale it back to the testing scale to obtain RoIs.
+    det_bboxes = det_bboxes[..., :4]
+    batch_index = torch.arange(
+        det_bboxes.size(0),
+        device=det_bboxes.device).float().view(-1, 1, 1).expand(
+            det_bboxes.size(0), det_bboxes.size(1), 1)
+    mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
+    mask_rois = mask_rois.view(-1, 5)
+    mask_results = self._mask_forward(x, mask_rois)
+    mask_pred = mask_results['mask_pred']
+    max_shape = img_metas['img_shape']
+    num_det = det_bboxes.shape[1]
+    det_bboxes = det_bboxes.reshape(-1, 4)
+    det_labels = det_labels.reshape(-1)
+    segm_results = self.mask_head.get_seg_masks(mask_pred, det_bboxes,
+                                                det_labels, self.test_cfg,
+                                                max_shape)
+    segm_results = segm_results.reshape(batch_size, num_det, max_shape[0],
+                                        max_shape[1])
+    return segm_results
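Note: both mixin rewrites build RoIs for the batched RoI ops by prepending a batch-index column and flattening, so a (batch, num_proposals, 4) tensor becomes the (batch * num_proposals, 5) layout RoIAlign expects. A small illustration of that layout in plain PyTorch:

import torch

batch_size, num_proposals = 2, 3
proposals = torch.rand(batch_size, num_proposals, 4)

batch_index = torch.arange(
    batch_size, device=proposals.device).float().view(-1, 1, 1).expand(
        batch_size, num_proposals, 1)
rois = torch.cat([batch_index, proposals], dim=-1)   # (2, 3, 5)
rois = rois.view(-1, 5)                              # (6, 5)

# first column tells the RoI op which image in the batch each roi belongs to
print(rois[:, 0])   # tensor([0., 0., 0., 1., 1., 1.])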
@@ -119,7 +119,8 @@ def main():
             device=args.device,
             output_file=f'output_{backend}.jpg',
             show_result=args.show,
-            ret_value=ret_value))
+            ret_value=ret_value),
+        ret_value=ret_value)
 
     # visualize pytorch model
     create_process(