Support exporting to ONNX for Faster R-CNN (#20)

* update rewriter for rpn head

* new nms implement

* update bbox_head to support fasterrcnn with tensorrt

* rewrite SingleRoIExtractor

* remove unnecessary import

* fix fcos and bbox_nms

* resolve comments

* resolve comments

Co-authored-by: grimoire <yaoqian@sensetime.com>
pull/12/head
RunningLeon 2021-07-21 19:46:23 +08:00 committed by GitHub
parent b9e64f9e1c
commit 10a2385d01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 386 additions and 65 deletions

View File

@ -114,10 +114,7 @@ def init_backend_model(model_files: Sequence[str],
raise NotImplementedError(f'Unknown codebase type: {codebase}')
def get_classes_from_config(
codebase: str,
model_cfg: Union[str, mmcv.Config],
):
def get_classes_from_config(codebase: str, model_cfg: Union[str, mmcv.Config]):
model_cfg_str = model_cfg
if codebase == 'mmdet':
if module_exist(codebase):

View File

@ -15,7 +15,9 @@ class DummyONNXNMSop(torch.autograd.Function):
score_threshold):
batch_size, num_class, num_box = scores.shape
# create dummy indices of nms output
num_fake_det = 2
# number of detection should be large enough to
# cover all layers of fpn
num_fake_det = 100
batch_inds = torch.randint(batch_size, (num_fake_det, 1))
cls_inds = torch.randint(num_class, (num_fake_det, 1))
box_inds = torch.randint(num_box, (num_fake_det, 1))

View File

@ -47,8 +47,7 @@ def _multiclass_nms(boxes,
iou_threshold=0.5,
score_threshold=0.05,
pre_top_k=-1,
keep_top_k=-1,
labels=None):
keep_top_k=-1):
"""Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.
This function helps exporting to onnx with batch and multiclass NMS op.
@ -69,9 +68,6 @@ def _multiclass_nms(boxes,
Defaults to -1.
keep_top_k (int): Number of top K boxes to keep after nms.
Defaults to -1.
labels (Tensor, optional): It not None, explicit labels would be used.
Otherwise, labels would be automatically generated using
num_classed. Defaults to None.
Returns:
tuple[Tensor, Tensor]: dets of shape [N, num_det, 5] and class labels
@ -81,28 +77,20 @@ def _multiclass_nms(boxes,
iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
score_threshold = torch.tensor([score_threshold], dtype=torch.float32)
batch_size = scores.shape[0]
num_class = scores.shape[2]
if pre_top_k > 0:
max_scores, _ = scores.max(-1)
_, topk_inds = max_scores.topk(pre_top_k)
batch_inds = torch.arange(batch_size).view(
-1, 1).expand_as(topk_inds).long()
# Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
boxes = boxes[batch_inds, topk_inds, :]
scores = scores[batch_inds, topk_inds, :]
if labels is not None:
labels = labels[batch_inds, topk_inds]
scores = scores.permute(0, 2, 1)
selected_indices = DummyONNXNMSop.apply(boxes, scores,
max_output_boxes_per_class,
iou_threshold, score_threshold)
if labels is None:
labels = torch.arange(num_class, dtype=torch.long).to(scores.device)
labels = labels.view(1, num_class, 1).expand_as(scores)
dets, labels = select_nms_index(
scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k)
return dets, labels
@ -118,8 +106,7 @@ def multiclass_nms_static(ctx,
iou_threshold=0.5,
score_threshold=0.05,
pre_top_k=-1,
keep_top_k=-1,
labels=None):
keep_top_k=-1):
boxes = boxes if boxes.dim() == 4 else boxes.unsqueeze(2)
keep_top_k = max_output_boxes_per_class if keep_top_k < 0 else min(
max_output_boxes_per_class, keep_top_k)

View File

@ -1,4 +1,8 @@
from .anchor_head import get_bboxes_of_anchor_head
from .fcos_head import get_bboxes_of_fcos_head
from .rpn_head import get_bboxes_of_rpn_head
__all__ = ['get_bboxes_of_anchor_head', 'get_bboxes_of_fcos_head']
__all__ = [
'get_bboxes_of_anchor_head', 'get_bboxes_of_fcos_head',
'get_bboxes_of_rpn_head'
]

View File

@ -7,13 +7,11 @@ from mmdeploy.utils import is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.AnchorHead.get_bboxes')
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.RetinaHead.get_bboxes')
def get_bboxes_of_anchor_head(ctx,
self,
cls_scores,
bbox_preds,
img_shape,
img_metas,
with_nms=True,
cfg=None,
**kwargs):
@ -36,7 +34,8 @@ def get_bboxes_of_anchor_head(ctx,
pre_topk = cfg.get('nms_pre', -1)
# loop over features, decode boxes
mlvl_bboxes = []
mlvl_valid_bboxes = []
mlvl_valid_anchors = []
mlvl_scores = []
for level_id, cls_score, bbox_pred, anchors in zip(
range(num_levels), mlvl_cls_scores, mlvl_bbox_preds, mlvl_anchors):
@ -74,20 +73,23 @@ def get_bboxes_of_anchor_head(ctx,
# BG cat_id: num_class
max_scores, _ = scores[..., :-1].max(-1)
_, topk_inds = max_scores.topk(pre_topk)
batch_inds = torch.arange(batch_size).view(-1,
1).expand_as(topk_inds)
batch_inds = torch.arange(
batch_size, device=device).view(-1, 1).expand_as(topk_inds)
anchors = anchors[batch_inds, topk_inds, :]
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
scores = scores[batch_inds, topk_inds, :]
bboxes = self.bbox_coder.decode(
anchors, bbox_pred, max_shape=img_shape)
mlvl_bboxes.append(bboxes)
mlvl_valid_bboxes.append(bbox_pred)
mlvl_scores.append(scores)
mlvl_valid_anchors.append(anchors)
batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
batch_mlvl_valid_bboxes = torch.cat(mlvl_valid_bboxes, dim=1)
batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1)
batch_mlvl_bboxes = self.bbox_coder.decode(
batch_mlvl_anchors,
batch_mlvl_valid_bboxes,
max_shape=img_metas['img_shape'])
# ignore background class
if not self.use_sigmoid_cls:
batch_mlvl_scores = batch_mlvl_scores[..., :self.num_classes]

View File

@ -12,7 +12,7 @@ def get_bboxes_of_fcos_head(ctx,
cls_scores,
bbox_preds,
centernesses,
img_shape,
img_metas,
with_nms=True,
cfg=None,
**kwargs):
@ -87,7 +87,7 @@ def get_bboxes_of_fcos_head(ctx,
batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
batch_mlvl_centerness = torch.cat(mlvl_centerness, dim=1)
batch_mlvl_bboxes = distance2bbox(
batch_mlvl_points, batch_mlvl_bboxes, max_shape=img_shape)
batch_mlvl_points, batch_mlvl_bboxes, max_shape=img_metas['img_shape'])
if not with_nms:
return batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_centerness

View File

@ -0,0 +1,112 @@
import torch
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.mmdet.core import multiclass_nms
from mmdeploy.utils import is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter('mmdet.models.RPNHead.get_bboxes')
def get_bboxes_of_rpn_head(ctx,
self,
cls_scores,
bbox_preds,
img_metas,
with_nms=True,
cfg=None,
**kwargs):
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
num_levels = len(cls_scores)
device = cls_scores[0].device
featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
mlvl_anchors = self.anchor_generator.grid_anchors(
featmap_sizes, device=device)
mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) == len(mlvl_anchors)
cfg = self.test_cfg if cfg is None else cfg
batch_size = mlvl_cls_scores[0].shape[0]
pre_topk = cfg.get('nms_pre', -1)
# loop over features, decode boxes
mlvl_valid_bboxes = []
mlvl_scores = []
mlvl_valid_anchors = []
for level_id, cls_score, bbox_pred, anchors in zip(
range(num_levels), mlvl_cls_scores, mlvl_bbox_preds, mlvl_anchors):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
cls_score = cls_score.permute(0, 2, 3, 1)
if self.use_sigmoid_cls:
cls_score = cls_score.reshape(batch_size, -1)
scores = cls_score.sigmoid()
else:
cls_score = cls_score.reshape(batch_size, -1, 2)
# We set FG labels to [0, num_class-1] and BG label to
# num_class in RPN head since mmdet v2.5, which is unified to
# be consistent with other head since mmdet v2.0. In mmdet v2.0
# to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
scores = cls_score.softmax(-1)[..., 0]
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
# use static anchor if input shape is static
if not is_dynamic_flag:
anchors = anchors.data
anchors = anchors.expand_as(bbox_pred)
enable_nms_pre = True
backend = deploy_cfg['backend']
# topk in tensorrt does not support shape<k
# final level might meet the problem
# TODO: support dynamic shape feature with TensorRT for topK op
if backend == 'tensorrt':
enable_nms_pre = (level_id != num_levels - 1)
if pre_topk > 0 and enable_nms_pre:
_, topk_inds = scores.topk(pre_topk)
batch_inds = torch.arange(
batch_size, device=device).view(-1, 1).expand_as(topk_inds)
# Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
transformed_inds = scores.shape[1] * batch_inds + topk_inds
scores = scores.reshape(-1, 1)[transformed_inds].reshape(
batch_size, -1)
bbox_pred = bbox_pred.reshape(-1, 4)[transformed_inds, :].reshape(
batch_size, -1, 4)
anchors = anchors.reshape(-1, 4)[transformed_inds, :].reshape(
batch_size, -1, 4)
mlvl_valid_bboxes.append(bbox_pred)
mlvl_scores.append(scores)
mlvl_valid_anchors.append(anchors)
batch_mlvl_bboxes = torch.cat(mlvl_valid_bboxes, dim=1)
batch_mlvl_scores = torch.cat(mlvl_scores, dim=1).unsqueeze(2)
batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1)
batch_mlvl_bboxes = self.bbox_coder.decode(
batch_mlvl_anchors,
batch_mlvl_bboxes,
max_shape=img_metas['img_shape'])
# ignore background class
if not self.use_sigmoid_cls:
batch_mlvl_scores = batch_mlvl_scores[..., :self.num_classes]
if not with_nms:
return batch_mlvl_bboxes, batch_mlvl_scores
post_params = deploy_cfg.post_processing
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.pre_top_k
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
# only one class in rpn
max_output_boxes_per_class = keep_top_k
return multiclass_nms(
batch_mlvl_bboxes,
batch_mlvl_scores,
max_output_boxes_per_class,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k)

View File

@ -1,4 +1,9 @@
from .single_stage import forward_of_single_stage
from .base import forward_of_base_detector
from .rpn import simple_test_of_rpn
from .single_stage import simple_test_of_single_stage
from .two_stage import extract_feat_of_two_stage
__all__ = ['forward_of_single_stage', 'extract_feat_of_two_stage']
__all__ = [
'simple_test_of_single_stage', 'extract_feat_of_two_stage',
'forward_of_base_detector', 'simple_test_of_rpn'
]

View File

@ -0,0 +1,22 @@
import torch
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.BaseDetector.forward')
def forward_of_base_detector(ctx, self, img, img_metas=None, **kwargs):
if img_metas is None:
img_metas = {}
assert isinstance(img_metas, dict)
assert isinstance(img, torch.Tensor)
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
# get origin input shape as tensor to support onnx dynamic shape
img_shape = torch._shape_as_tensor(img)[2:]
if not is_dynamic_flag:
img_shape = [int(val) for val in img_shape]
img_metas['img_shape'] = img_shape
return self.simple_test(img, img_metas, **kwargs)

View File

@ -0,0 +1,7 @@
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(func_name='mmdet.models.RPN.simple_test')
def simple_test_of_rpn(ctx, self, img, img_metas, **kwargs):
x = self.extract_feat(img)
return self.rpn_head.simple_test_rpn(x, img_metas)

View File

@ -1,26 +1,8 @@
import torch
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.SingleStageDetector.extract_feat')
def extract_feat_of_single_stage(ctx, self, img):
return ctx.origin_func(self, img)
@FUNCTION_REWRITER.register_rewriter(func_name='mmdet.models.RetinaNet.forward'
)
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.SingleStageDetector.forward')
def forward_of_single_stage(ctx, self, data, **kwargs):
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
# get origin input shape to support onnx dynamic shape
img_shape = torch._shape_as_tensor(data)[2:]
if not is_dynamic_flag:
img_shape = [int(val) for val in img_shape]
x = self.extract_feat(data)
outs = self.bbox_head(x)
return self.bbox_head.get_bboxes(*outs, img_shape, **kwargs)
func_name='mmdet.models.SingleStageDetector.simple_test')
def simple_test_of_single_stage(ctx, self, img, img_metas, **kwargs):
feat = self.extract_feat(img)
return self.bbox_head.simple_test(feat, img_metas, **kwargs)

View File

@ -6,3 +6,18 @@ from mmdeploy.core import FUNCTION_REWRITER, mark
@mark('extract_feat', inputs='img', outputs='feat')
def extract_feat_of_two_stage(ctx, self, img):
return ctx.origin_func(self, img)
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.TwoStageDetector.simple_test')
def simple_test_of_two_stage(ctx,
self,
img,
img_metas,
proposals=None,
**kwargs):
assert self.with_bbox, 'Bbox head must be implemented.'
x = self.extract_feat(img)
if proposals is None:
proposals, _ = self.rpn_head.simple_test_rpn(x, img_metas)
return self.roi_head.simple_test(x, proposals, img_metas, rescale=False)

View File

@ -1 +1,4 @@
from .bbox_heads import * # noqa: F401, F403
from .roi_extractors import * # noqa: F401, F403
from .standard_roi_head import * # noqa: F401, F403
from .test_mixins import * # noqa: F401, F403

View File

@ -0,0 +1,3 @@
from .bbox_head import get_bboxes_of_bbox_head
__all__ = ['get_bboxes_of_bbox_head']

View File

@ -0,0 +1,62 @@
import torch
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.mmdet.core import multiclass_nms
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.roi_heads.BBoxHead.get_bboxes')
def get_bboxes_of_bbox_head(ctx, self, rois, cls_score, bbox_pred, img_shape,
cfg, **kwargs):
assert rois.ndim == 3, 'Only support export two stage ' \
'model to ONNX ' \
'with batch dimension. '
if self.custom_cls_channels:
scores = self.loss_cls.get_activation(cls_score)
else:
scores = F.softmax(
cls_score, dim=-1) if cls_score is not None else None
if bbox_pred is not None:
bboxes = self.bbox_coder.decode(
rois[..., 1:], bbox_pred, max_shape=img_shape)
else:
bboxes = rois[..., 1:].clone()
if img_shape is not None:
max_shape = bboxes.new_tensor(img_shape)[..., :2]
min_xy = bboxes.new_tensor(0)
max_xy = torch.cat([max_shape] * 2, dim=-1).flip(-1).unsqueeze(-2)
bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
batch_size = scores.shape[0]
device = scores.device
# ignore background class
scores = scores[..., :self.num_classes]
if not self.reg_class_agnostic:
# only keep boxes with the max scores
max_inds = scores.reshape(-1, self.num_classes).argmax(1, keepdim=True)
bboxes = bboxes.reshape(-1, self.num_classes, 4)
dim0_inds = torch.arange(
bboxes.shape[0], device=device).view(-1, 1).expand_as(max_inds)
bboxes = bboxes[dim0_inds, max_inds].reshape(batch_size, -1, 4)
# get nms params
post_params = ctx.cfg.post_processing
max_output_boxes_per_class = post_params.max_output_boxes_per_class
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.pre_top_k
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
dets, labels = multiclass_nms(
bboxes,
scores,
max_output_boxes_per_class,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k)
return dets, labels

View File

@ -1 +1,8 @@
from .base_roi_extractor import * # noqa: F401, F403
from .single_level_roi_extractor import (
forward_of_single_roi_extractor_dynamic,
forward_of_single_roi_extractor_static)
__all__ = [
'forward_of_single_roi_extractor_dynamic',
'forward_of_single_roi_extractor_static'
]

View File

@ -51,11 +51,11 @@ class MultiLevelRoiAlign(Function):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward',
backend='tensorrt')
def SingleRoIExtractor_forward_static(rewriter,
self,
feats,
rois,
roi_scale_factor=None):
def forward_of_single_roi_extractor_static(ctx,
self,
feats,
rois,
roi_scale_factor=None):
featmap_strides = self.featmap_strides
finest_scale = self.finest_scale
@ -70,3 +70,31 @@ def SingleRoIExtractor_forward_static(rewriter,
return MultiLevelRoiAlign.apply(*feats, rois, out_size, sampling_ratio,
roi_scale_factor, finest_scale,
featmap_strides, aligned)
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward')
def forward_of_single_roi_extractor_dynamic(ctx,
self,
feats,
rois,
roi_scale_factor=None):
out_size = self.roi_layers[0].output_size
num_levels = len(feats)
roi_feats = feats[0].new_zeros(rois.shape[0], self.out_channels, *out_size)
if num_levels == 1:
assert len(rois) > 0, 'The number of rois should be positive'
return self.roi_layers[0](feats[0], rois)
target_lvls = self.map_roi_levels(rois, num_levels)
if roi_scale_factor is not None:
rois = self.roi_rescale(rois, roi_scale_factor)
for i in range(num_levels):
mask = target_lvls == i
inds = mask.nonzero(as_tuple=False).squeeze(1)
rois_i = rois[inds]
roi_feats_t = self.roi_layers[i](feats[i], rois_i)
roi_feats[inds] = roi_feats_t
return roi_feats

View File

@ -0,0 +1,16 @@
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.roi_heads.StandardRoIHead.simple_test')
def simple_test_of_standard_roi_head(ctx, self, x, proposals, img_metas,
**kwargs):
assert self.with_bbox, 'Bbox head must be implemented.'
det_bboxes, det_labels = self.simple_test_bboxes(
x, img_metas, proposals, self.test_cfg, rescale=False)
if not self.with_mask:
return det_bboxes, det_labels
segm_results = self.simple_test_mask(
x, img_metas, det_bboxes, det_labels, rescale=False)
return det_bboxes, det_labels, segm_results

View File

@ -0,0 +1,66 @@
import torch
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.roi_heads.test_mixins.\
BBoxTestMixin.simple_test_bboxes')
def simple_test_bboxes_of_bbox_test_mixin(ctx, self, x, img_metas, proposals,
rcnn_test_cfg, **kwargs):
rois = proposals
batch_index = torch.arange(
rois.shape[0], device=rois.device).float().view(-1, 1, 1).expand(
rois.size(0), rois.size(1), 1)
rois = torch.cat([batch_index, rois[..., :4]], dim=-1)
batch_size = rois.shape[0]
num_proposals_per_img = rois.shape[1]
# Eliminate the batch dimension
rois = rois.view(-1, 5)
bbox_results = self._bbox_forward(x, rois)
cls_score = bbox_results['cls_score']
bbox_pred = bbox_results['bbox_pred']
# Recover the batch dimension
rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1))
cls_score = cls_score.reshape(batch_size, num_proposals_per_img,
cls_score.size(-1))
bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img,
bbox_pred.size(-1))
det_bboxes, det_labels = self.bbox_head.get_bboxes(
rois, cls_score, bbox_pred, img_metas['img_shape'], cfg=rcnn_test_cfg)
return det_bboxes, det_labels
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.roi_heads.test_mixins.\
MaskTestMixin.simple_test_mask')
def simple_test_mask_of_mask_test_mixin(ctx, self, x, img_metas, det_bboxes,
det_labels, **kwargs):
assert det_bboxes.shape[1] != 0, 'Can not record MaskHead as it \
has not been executed this time'
batch_size = det_bboxes.size(0)
# if det_bboxes is rescaled to the original image size, we need to
# rescale it back to the testing scale to obtain RoIs.
det_bboxes = det_bboxes[..., :4]
batch_index = torch.arange(
det_bboxes.size(0),
device=det_bboxes.device).float().view(-1, 1, 1).expand(
det_bboxes.size(0), det_bboxes.size(1), 1)
mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
mask_rois = mask_rois.view(-1, 5)
mask_results = self._mask_forward(x, mask_rois)
mask_pred = mask_results['mask_pred']
max_shape = img_metas['img_shape']
num_det = det_bboxes.shape[1]
det_bboxes = det_bboxes.reshape(-1, 4)
det_labels = det_labels.reshape(-1)
segm_results = self.mask_head.get_seg_masks(mask_pred, det_bboxes,
det_labels, self.test_cfg,
max_shape)
segm_results = segm_results.reshape(batch_size, num_det, max_shape[0],
max_shape[1])
return segm_results

View File

@ -119,7 +119,8 @@ def main():
device=args.device,
output_file=f'output_{backend}.jpg',
show_result=args.show,
ret_value=ret_value))
ret_value=ret_value),
ret_value=ret_value)
# visualize pytorch model
create_process(