[Feat] Add end2end pointpillars & centerpoint(pillar) deployment for mmdet3d (#1178)

* add end2end pointpillars & centerpoint(pillar)

* fix centerpoint UT

* uncomment pycuda

* add nvidia copyright and remove post_process for voxel_detection_model

* keep comments

* add anchor3d_head init

* remove pycuda comment

* add pcd test sample
pull/1366/head
Jiahao Sun 2022-11-24 16:14:50 +08:00 committed by GitHub
parent 301035a06f
commit 9bbe3c0355
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 385 additions and 190 deletions

View File

@ -3,4 +3,4 @@ codebase_config = dict(
type='mmdet3d', task='VoxelDetection', model_type='end2end')
onnx_config = dict(
input_names=['voxels', 'num_points', 'coors'],
output_names=['scores', 'bbox_preds', 'dir_scores'])
output_names=['bboxes', 'scores', 'labels'])

View File

@ -1,3 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import anchor # noqa: F401,F403
from . import bbox # noqa: F401,F403
from . import post_processing # noqa: F401,F403

View File

@ -0,0 +1,2 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .anchor_3d_generator import * # noqa: F401,F403

View File

@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmdet3d.core.anchor.anchor_3d_generator.AlignedAnchor3DRangeGenerator.'
'anchors_single_range')
def alignedanchor3drangegenerator__anchors_single_range(
ctx,
self,
feature_size,
anchor_range,
scale,
sizes=[[3.9, 1.6, 1.56]],
rotations=[0, 1.5707963],
device='cuda'):
"""Generate anchors in a single range. Rewrite this func for default
backend.
Args:
feature_size (list[float] | tuple[float]): Feature map size. It is
either a list of a tuple of [D, H, W](in order of z, y, and x).
anchor_range (torch.Tensor | list[float]): Range of anchors with
shape [6]. The order is consistent with that of anchors, i.e.,
(x_min, y_min, z_min, x_max, y_max, z_max).
scale (float | int): The scale factor of anchors.
sizes (list[list] | np.ndarray | torch.Tensor, optional):
Anchor size with shape [N, 3], in order of x, y, z.
Defaults to [[3.9, 1.6, 1.56]].
rotations (list[float] | np.ndarray | torch.Tensor, optional):
Rotations of anchors in a single feature grid.
Defaults to [0, 1.5707963].
device (str, optional): Devices that the anchors will be put on.
Defaults to 'cuda'.
Returns:
torch.Tensor: Anchors with shape
[*feature_size, num_sizes, num_rots, 7].
"""
if len(feature_size) == 2:
feature_size = [1, feature_size[0], feature_size[1]]
anchor_range = torch.tensor(anchor_range, device=device)
z_centers = torch.arange(feature_size[0], device=device)
z_centers = z_centers.to(anchor_range.dtype)
y_centers = torch.arange(feature_size[1], device=device)
y_centers = y_centers.to(anchor_range.dtype)
x_centers = torch.arange(feature_size[2], device=device)
x_centers = x_centers.to(anchor_range.dtype)
# shift the anchor center
if not self.align_corner:
z_centers += 0.5
y_centers += 0.5
x_centers += 0.5
z_centers = z_centers / feature_size[0] * (
anchor_range[5] - anchor_range[2]) + anchor_range[2]
y_centers = y_centers / feature_size[1] * (
anchor_range[4] - anchor_range[1]) + anchor_range[1]
x_centers = x_centers / feature_size[2] * (
anchor_range[3] - anchor_range[0]) + anchor_range[0]
sizes = torch.tensor(sizes, device=device).reshape(-1, 3) * scale
rotations = torch.tensor(rotations, device=device)
# torch.meshgrid default behavior is 'id', np's default is 'xy'
rets = torch.meshgrid(x_centers, y_centers, z_centers, rotations)
# torch.meshgrid returns a tuple rather than list
rets = list(rets)
tile_shape = [1] * 5
tile_shape[-2] = int(sizes.shape[0])
for i in range(len(rets)):
rets[i] = rets[i].unsqueeze(-2).repeat(tile_shape).unsqueeze(-1)
sizes = sizes.reshape([1, 1, 1, -1, 1, 3])
tile_size_shape = list(rets[0].shape)
tile_size_shape[3] = 1
sizes = sizes.repeat(tile_size_shape)
rets.insert(3, sizes)
ret = torch.cat(rets, dim=-1).permute([2, 1, 0, 3, 4, 5])
if len(self.custom_values) > 0:
custom_ndim = len(self.custom_values)
custom = ret.new_zeros([*ret.shape[:-1], custom_ndim])
# TODO: check the support of custom values
# custom[:] = self.custom_values
ret = torch.cat([ret, custom], dim=-1)
return ret

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import centerpoint_bbox_coders # noqa: F401,F403
from . import fcos3d_bbox_coder # noqa: F401,F403
from .utils import points_img2cam

View File

@ -0,0 +1,108 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmdet3d.core.bbox.coders.centerpoint_bbox_coders.CenterPointBBoxCoder.'
'decode')
def centerpointbboxcoder__decode(ctx,
self,
heat,
rot_sine,
rot_cosine,
hei,
dim,
vel,
reg=None,
task_id=-1):
"""Decode bboxes. Rewrite this func for default backend.
Args:
heat (torch.Tensor): Heatmap with the shape of [B, N, W, H].
rot_sine (torch.Tensor): Sine of rotation with the shape of
[B, 1, W, H].
rot_cosine (torch.Tensor): Cosine of rotation with the shape of
[B, 1, W, H].
hei (torch.Tensor): Height of the boxes with the shape
of [B, 1, W, H].
dim (torch.Tensor): Dim of the boxes with the shape of
[B, 1, W, H].
vel (torch.Tensor): Velocity with the shape of [B, 1, W, H].
reg (torch.Tensor, optional): Regression value of the boxes in
2D with the shape of [B, 2, W, H]. Default: None.
task_id (int, optional): Index of task. Default: -1.
Returns:
list[dict]: Decoded boxes.
"""
batch, cat, _, _ = heat.size()
scores, inds, clses, ys, xs = self._topk(heat, K=self.max_num)
if reg is not None:
reg = self._transpose_and_gather_feat(reg, inds)
reg = reg.view(batch, self.max_num, 2)
xs = xs.view(batch, self.max_num, 1) + reg[:, :, 0:1]
ys = ys.view(batch, self.max_num, 1) + reg[:, :, 1:2]
else:
xs = xs.view(batch, self.max_num, 1) + 0.5
ys = ys.view(batch, self.max_num, 1) + 0.5
# rotation value and direction label
rot_sine = self._transpose_and_gather_feat(rot_sine, inds)
rot_sine = rot_sine.view(batch, self.max_num, 1)
rot_cosine = self._transpose_and_gather_feat(rot_cosine, inds)
rot_cosine = rot_cosine.view(batch, self.max_num, 1)
rot = torch.atan2(rot_sine, rot_cosine)
# height in the bev
hei = self._transpose_and_gather_feat(hei, inds)
hei = hei.view(batch, self.max_num, 1)
# dim of the box
dim = self._transpose_and_gather_feat(dim, inds)
dim = dim.view(batch, self.max_num, 3)
# class label
clses = clses.view(batch, self.max_num).float()
scores = scores.view(batch, self.max_num)
xs = xs.view(
batch, self.max_num,
1) * self.out_size_factor * self.voxel_size[0] + self.pc_range[0]
ys = ys.view(
batch, self.max_num,
1) * self.out_size_factor * self.voxel_size[1] + self.pc_range[1]
if vel is None: # KITTI FORMAT
final_box_preds = torch.cat([xs, ys, hei, dim, rot], dim=2)
else: # exist velocity, nuscene format
vel = self._transpose_and_gather_feat(vel, inds)
vel = vel.view(batch, self.max_num, 2)
final_box_preds = torch.cat([xs, ys, hei, dim, rot, vel], dim=2)
final_scores = scores
final_preds = clses
self.post_center_range = torch.tensor(
self.post_center_range, device=heat.device)
range_mask = torch.prod(
torch.cat((final_box_preds[..., :3] >= self.post_center_range[:3],
final_box_preds[..., :3] <= self.post_center_range[3:]),
dim=-1),
dim=-1).bool()
final_box_preds = torch.where(
range_mask.unsqueeze(-1), final_box_preds,
torch.zeros(1, device=heat.device))
final_scores = torch.where(range_mask, final_scores,
torch.zeros(1, device=heat.device))
final_preds = torch.where(range_mask, final_preds,
torch.zeros(1, device=heat.device))
predictions_dict = {
'bboxes': final_box_preds[0],
'scores': final_scores[0],
'labels': final_preds[0],
}
return [predictions_dict]

View File

@ -163,8 +163,8 @@ def _box3d_multiclass_nms(
dir_scores, attr_scores)
# @FUNCTION_REWRITER.register_rewriter(
# func_name='mmdet3d.core.post_processing.box3d_multiclass_nms')
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet3d.core.post_processing.box3d_multiclass_nms')
def box3d_multiclass_nms(*args, **kwargs):
"""Wrapper function for `_box3d_multiclass_nms`."""
return mmdeploy.codebase.mmdet3d.core.post_processing.box3d_nms.\

View File

@ -95,6 +95,8 @@ class VoxelDetection(BaseTask):
score_thr (float): The score threshold to display the bbox.
Defaults to 0.3.
"""
if output_file.endswith('.jpg'):
output_file = output_file.split('.')[0]
from mmdet3d.apis import show_result_meshlab
data = VoxelDetection.read_pcd_file(image, self.model_cfg, self.device)
show_result_meshlab(

View File

@ -7,9 +7,8 @@ from mmcv.utils import Registry
from torch.nn import functional as F
from mmdeploy.codebase.base import BaseBackendModel
from mmdeploy.core import RewriterContext
from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
get_root_logger, load_config)
load_config)
def __build_backend_voxel_model(cls_name: str, registry: Registry, *args,
@ -83,7 +82,7 @@ class VoxelDetectionModel(BaseBackendModel):
Returns:
list: A list contains predictions.
"""
result_list = []
results = []
for i in range(len(img_metas)):
voxels, num_points, coors = VoxelDetectionModel.voxelize(
self.model_cfg, points[i])
@ -93,12 +92,15 @@ class VoxelDetectionModel(BaseBackendModel):
'coors': coors
}
outputs = self.wrapper(input_dict)
result = VoxelDetectionModel.post_process(self.model_cfg,
self.deploy_cfg, outputs,
img_metas[i],
self.device)[0]
result_list.append(result)
return result_list
outputs = self.wrapper.output_to_list(outputs)
outputs = [x.squeeze(0) for x in outputs]
bbox_dim = outputs[0].shape[-1]
outputs[0] = img_metas[0][0]['box_type_3d'](outputs[0], bbox_dim)
from mmdet3d.core import bbox3d2result
result = bbox3d2result(*outputs)
results.append(result)
return results
def show_result(self,
data: Dict,
@ -171,66 +173,6 @@ class VoxelDetectionModel(BaseBackendModel):
coors_batch = torch.cat(coors_batch, dim=0)
return voxels, num_points, coors_batch
@staticmethod
def post_process(model_cfg: Union[str, mmcv.Config],
deploy_cfg: Union[str, mmcv.Config],
outs: Dict,
img_metas: Dict,
device: str,
rescale=False):
"""model post process.
Args:
model_cfg (str | mmcv.Config): The model config.
deploy_cfg (str|mmcv.Config): Deployment config file or loaded
Config object.
outs (Dict): Output of model's head.
img_metas(Dict): Meta info for pcd.
device (str): A string specifying device type.
rescale (list[torch.Tensor]): whether th rescale bbox.
Returns:
list: A list contains predictions, include bboxes, scores, labels.
"""
from mmdet3d.core import bbox3d2result
from mmdet3d.models.builder import build_head
model_cfg = load_config(model_cfg)[0]
deploy_cfg = load_config(deploy_cfg)[0]
if 'bbox_head' in model_cfg.model.keys():
head_cfg = dict(**model_cfg.model['bbox_head'])
elif 'pts_bbox_head' in model_cfg.model.keys():
head_cfg = dict(**model_cfg.model['pts_bbox_head'])
else:
raise NotImplementedError('Not supported model.')
head_cfg['train_cfg'] = None
head_cfg['test_cfg'] = model_cfg.model['test_cfg']\
if 'pts' not in model_cfg.model['test_cfg'].keys()\
else model_cfg.model['test_cfg']['pts']
head = build_head(head_cfg)
if device == 'cpu':
logger = get_root_logger()
logger.warning(
'Don\'t suggest using CPU device. Post process can\'t support.'
)
if torch.cuda.is_available():
device = 'cuda'
else:
raise NotImplementedError(
'Post process don\'t support device=cpu')
cls_scores = [outs['scores'].to(device)]
bbox_preds = [outs['bbox_preds'].to(device)]
dir_scores = [outs['dir_scores'].to(device)]
with RewriterContext(
cfg=deploy_cfg,
backend=deploy_cfg.backend_config.type,
opset=deploy_cfg.onnx_config.opset_version):
bbox_list = head.get_bboxes(
cls_scores, bbox_preds, dir_scores, img_metas, rescale=False)
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results
def build_voxel_detection_model(model_files: Sequence[str],
model_cfg: Union[str, mmcv.Config],

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import anchor3d_head # noqa: F401,F403
from . import anchor_free_mono3d_head # noqa: F401,F403
from . import base # noqa: F401,F403
from . import centerpoint # noqa: F401,F403

View File

@ -0,0 +1,107 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmdet3d.core.bbox.structures import limit_period
from mmdeploy.codebase.mmdet3d.core.post_processing import box3d_multiclass_nms
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmdet3d.models.dense_heads.anchor3d_head.Anchor3DHead.'
'get_bboxes')
def anchor3dhead__get_bboxes(ctx,
self,
cls_scores,
bbox_preds,
dir_cls_preds,
input_metas,
cfg=None,
rescale=False):
"""Rewrite `get_bboxes` of `Anchor3DHead` for default backend.
Args:
ctx (ContextCaller): The context with additional information.
self (FoveaHead): The instance of the class FoveaHead.
cls_scores (list[Tensor]): Box scores for each scale level
with shape (N, num_anchors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_anchors * 7, H, W).
dir_cls_preds (list[Tensor]): Direction predicts for
all scale level, each is a 4D-tensor, has shape
(batch_size, num_priors * 1, H, W).
input_metas (list[dict]): Meta information of the image, e.g.,
image size, scaling factor, etc.
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used. Default: None.
rescale (bool): If True, return boxes in original image space.
Default: False.
Returns:
tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (bboxes, scores, labels),
`bboxes` of shape [N, num_det, 7] ,`scores` of shape
[N, num_det] and `labels` of shape [N, num_det].
"""
assert len(cls_scores) == len(bbox_preds) == len(dir_cls_preds)
num_levels = len(cls_scores)
featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
device = cls_scores[0].device
mlvl_anchors = self.anchor_generator.grid_anchors(
featmap_sizes, device=device)
mlvl_anchors = [
anchor.reshape(-1, self.box_code_size) for anchor in mlvl_anchors
]
cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
dir_cls_preds = [dir_cls_preds[i].detach() for i in range(num_levels)]
cfg = self.test_cfg if cfg is None else cfg
mlvl_bboxes = []
mlvl_scores = []
mlvl_dir_scores = []
for cls_score, bbox_pred, dir_cls_pred, anchors in zip(
cls_scores, bbox_preds, dir_cls_preds, mlvl_anchors):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
dir_cls_pred = dir_cls_pred.permute(0, 2, 3, 1).reshape(1, -1, 2)
dir_cls_score = torch.max(dir_cls_pred, dim=-1)[1]
cls_score = cls_score.permute(0, 2, 3,
1).reshape(1, -1, self.num_classes)
if self.use_sigmoid_cls:
scores = cls_score.sigmoid()
else:
scores = cls_score.softmax(-1)
bbox_pred = bbox_pred.permute(0, 2, 3,
1).reshape(1, -1, self.box_code_size)
nms_pre = cfg.get('nms_pre', -1)
if nms_pre > 0 and scores.shape[1] > nms_pre:
if self.use_sigmoid_cls:
max_scores, _ = scores.max(dim=2)
else:
max_scores, _ = scores[..., :-1].max(dim=2)
max_scores = max_scores[0]
_, topk_inds = max_scores.topk(nms_pre)
anchors = anchors[topk_inds, :]
bbox_pred = bbox_pred[:, topk_inds, :]
scores = scores[:, topk_inds, :]
dir_cls_score = dir_cls_score[:, topk_inds]
bboxes = self.bbox_coder.decode(anchors, bbox_pred)
mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores)
mlvl_dir_scores.append(dir_cls_score)
mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
mlvl_bboxes_for_nms = mlvl_bboxes[..., [0, 1, 3, 4, 6]].clone()
mlvl_scores = torch.cat(mlvl_scores, dim=1)
mlvl_dir_scores = torch.cat(mlvl_dir_scores, dim=1)
if mlvl_bboxes.shape[0] > 0:
dir_rot = limit_period(mlvl_bboxes[..., 6] - self.dir_offset,
self.dir_limit_offset, np.pi)
mlvl_bboxes[..., 6] = (
dir_rot + self.dir_offset +
np.pi * mlvl_dir_scores.to(mlvl_bboxes.dtype))
return box3d_multiclass_nms(mlvl_bboxes, mlvl_bboxes_for_nms, mlvl_scores,
cfg.score_thr, cfg.nms_thr, cfg.max_num)

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet3d.core import circle_nms
from mmdeploy.codebase.mmdet3d.core.post_processing import box3d_multiclass_nms
from mmdeploy.core import FUNCTION_REWRITER
@ -48,37 +48,22 @@ def centerpoint__simple_test_pts(ctx, self, x, img_metas, rescale=False):
List: Result of model.
"""
outs = self.pts_bbox_head(x)
bbox_preds, scores, dir_scores = [], [], []
for task_res in outs:
bbox_preds.append(task_res[0]['reg'])
bbox_preds.append(task_res[0]['height'])
bbox_preds.append(task_res[0]['dim'])
if 'vel' in task_res[0].keys():
bbox_preds.append(task_res[0]['vel'])
scores.append(task_res[0]['heatmap'])
dir_scores.append(task_res[0]['rot'])
bbox_preds = torch.cat(bbox_preds, dim=1)
scores = torch.cat(scores, dim=1)
dir_scores = torch.cat(dir_scores, dim=1)
return scores, bbox_preds, dir_scores
bbox_list = self.pts_bbox_head.get_bboxes(outs, img_metas, rescale=rescale)
return bbox_list
@FUNCTION_REWRITER.register_rewriter(
'mmdet3d.models.dense_heads.centerpoint_head.CenterHead.get_bboxes')
def centerpoint__get_bbox(ctx,
self,
cls_scores,
bbox_preds,
dir_scores,
preds_dicts,
img_metas,
img=None,
rescale=False):
"""Rewrite this func to format func inputs.
Args
cls_scores (list[torch.Tensor]): Classification predicts results.
bbox_preds (list[torch.Tensor]): Bbox predicts results.
dir_scores (list[torch.Tensor]): Dir predicts results.
pred_dicts (list[dict]): Each task predicts results.
img_metas (list[dict]): Point cloud and image's meta info.
img (torch.Tensor): Input image.
rescale (Bool): Whether need rescale.
@ -87,44 +72,24 @@ def centerpoint__get_bbox(ctx,
list[dict]: Decoded bbox, scores and labels after nms.
"""
rets = []
scores_range = [0]
bbox_range = [0]
dir_range = [0]
for i, task_head in enumerate(self.task_heads):
scores_range.append(scores_range[i] + self.num_classes[i])
bbox_range.append(bbox_range[i] + 8)
dir_range.append(dir_range[i] + 2)
for task_id in range(len(self.num_classes)):
num_class_with_bg = self.num_classes[task_id]
for task_id, preds_dict in enumerate(preds_dicts):
batch_heatmap = preds_dict[0]['heatmap'].sigmoid()
batch_heatmap = cls_scores[
0][:, scores_range[task_id]:scores_range[task_id + 1],
...].sigmoid()
batch_reg = bbox_preds[0][:,
bbox_range[task_id]:bbox_range[task_id] + 2,
...]
batch_hei = bbox_preds[0][:, bbox_range[task_id] +
2:bbox_range[task_id] + 3, ...]
batch_reg = preds_dict[0]['reg']
batch_hei = preds_dict[0]['height']
if self.norm_bbox:
batch_dim = torch.exp(bbox_preds[0][:, bbox_range[task_id] +
3:bbox_range[task_id] + 6,
...])
batch_dim = torch.exp(preds_dict[0]['dim'])
else:
batch_dim = bbox_preds[0][:, bbox_range[task_id] +
3:bbox_range[task_id] + 6, ...]
batch_dim = preds_dict[0]['dim']
batch_vel = bbox_preds[0][:, bbox_range[task_id] +
6:bbox_range[task_id + 1], ...]
batch_rots = dir_scores[0][:,
dir_range[task_id]:dir_range[task_id + 1],
...][:, 0].unsqueeze(1)
batch_rotc = dir_scores[0][:,
dir_range[task_id]:dir_range[task_id + 1],
...][:, 1].unsqueeze(1)
batch_rots = preds_dict[0]['rot'][:, 0].unsqueeze(1)
batch_rotc = preds_dict[0]['rot'][:, 1].unsqueeze(1)
if 'vel' in preds_dict[0]:
batch_vel = preds_dict[0]['vel']
else:
batch_vel = None
temp = self.bbox_coder.decode(
batch_heatmap,
batch_rots,
@ -134,57 +99,32 @@ def centerpoint__get_bbox(ctx,
batch_vel,
reg=batch_reg,
task_id=task_id)
if 'pts' in self.test_cfg.keys():
self.test_cfg = self.test_cfg.pts
assert self.test_cfg['nms_type'] in ['circle', 'rotate']
batch_reg_preds = [box['bboxes'] for box in temp]
batch_cls_preds = [box['scores'] for box in temp]
batch_cls_labels = [box['labels'] for box in temp]
batch_bboxes = temp[0]['bboxes'].unsqueeze(0)
batch_scores = temp[0]['scores'].unsqueeze(0).unsqueeze(-1)
batch_cls_labels = temp[0]['labels'].unsqueeze(0).long()
batch_bboxes_for_nms = batch_bboxes[..., [0, 1, 3, 4, 6]].clone()
if self.test_cfg['nms_type'] == 'circle':
boxes3d = temp[0]['bboxes']
scores = temp[0]['scores']
labels = temp[0]['labels']
centers = boxes3d[:, [0, 1]]
boxes = torch.cat([centers, scores.view(-1, 1)], dim=1)
keep = torch.tensor(
circle_nms(
boxes.detach().cpu().numpy(),
self.test_cfg['min_radius'][task_id],
post_max_size=self.test_cfg['post_max_size']),
dtype=torch.long,
device=boxes.device)
boxes3d = boxes3d[keep]
scores = scores[keep]
labels = labels[keep]
ret = dict(bboxes=boxes3d, scores=scores, labels=labels)
ret_task = [ret]
rets.append(ret_task)
raise NotImplementedError(
'Not implement circle nms for deployment now!')
else:
rets.append(
self.get_task_detections(num_class_with_bg, batch_cls_preds,
batch_reg_preds, batch_cls_labels,
img_metas))
box3d_multiclass_nms(batch_bboxes, batch_bboxes_for_nms,
batch_scores,
self.test_cfg['score_threshold'],
self.test_cfg['nms_thr'],
self.test_cfg['post_max_size'], None,
batch_cls_labels))
# Merge branches results
num_samples = len(rets[0])
bboxes = torch.cat([ret[0] for ret in rets], dim=1)
bboxes[..., 2] = bboxes[..., 2] - bboxes[..., 5] * 0.5
scores = torch.cat([ret[1] for ret in rets], dim=1)
ret_list = []
for i in range(num_samples):
for k in rets[0][i].keys():
if k == 'bboxes':
bboxes = torch.cat([ret[i][k] for ret in rets])
bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
bboxes = img_metas[i]['box_type_3d'](bboxes,
self.bbox_coder.code_size)
elif k == 'scores':
scores = torch.cat([ret[i][k] for ret in rets])
elif k == 'labels':
flag = 0
for j, num_class in enumerate(self.num_classes):
rets[j][i][k] += flag
flag += num_class
labels = torch.cat([ret[i][k].int() for ret in rets])
ret_list.append([bboxes, scores, labels])
return ret_list
labels = [ret[3] for ret in rets]
flag = 0
for i, num_class in enumerate(self.num_classes):
labels[i] += flag
flag += num_class
labels = torch.cat(labels, dim=1)
return bboxes, scores, labels

View File

@ -102,5 +102,6 @@ def mvxtwostagedetector__simple_test_pts(ctx,
Returns:
List: Result of model.
"""
bbox_preds, scores, dir_scores = self.pts_bbox_head(x)
return bbox_preds, scores, dir_scores
outs = self.pts_bbox_head(x)
outs = self.pts_bbox_head.get_bboxes(*outs, img_metas, rescale=rescale)
return outs

View File

@ -30,8 +30,7 @@ def pointpillarsscatter__forward(ctx,
indices = indices.long()
voxels = voxel_features.t()
# Now scatter the blob back to the canvas.
canvas.scatter_(
dim=1, index=indices.expand(canvas.shape[0], -1), src=voxels)
canvas[:, indices] = voxels
# Undo the column stacking to final 4-dim tensor
canvas = canvas.view(1, self.in_channels, self.ny, self.nx)
return canvas

View File

@ -25,8 +25,9 @@ def voxelnet__simple_test(ctx,
List: Result of model.
"""
x = self.extract_feat(voxels, num_points, coors, img_metas)
bbox_preds, scores, dir_scores = self.bbox_head(x)
return bbox_preds, scores, dir_scores
outs = self.bbox_head(x)
outs = self.bbox_head.get_bboxes(*outs, img_metas, rescale=rescale)
return outs
@FUNCTION_REWRITER.register_rewriter(

View File

@ -170,6 +170,7 @@ class TRTBatchedBEVNMSop(torch.autograd.Function):
background_label_id: int = -1,
return_index: bool = True):
"""Forward of batched nms.
Args:
ctx (Context): The context with meta information.
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].

View File

@ -224,7 +224,7 @@ centerpoint_model = dict(
score_threshold=0.1,
out_size_factor=4,
voxel_size=voxel_size[:2],
nms_type='circle',
nms_type='rotate',
pre_max_size=1000,
post_max_size=83,
nms_thr=0.2)))

View File

@ -155,10 +155,7 @@ def test_centerpoint(backend_type: Backend):
cfg=deploy_cfg,
backend=deploy_cfg.backend_config.type,
opset=deploy_cfg.onnx_config.opset_version):
outputs = model.forward(*data)
head = get_centerpoint_head()
rewrite_outputs = head.get_bboxes(*[[i] for i in outputs],
inputs['img_metas'][0])
rewrite_outputs = model.forward(*data)
assert rewrite_outputs is not None

View File

@ -34,7 +34,7 @@ deploy_cfg = mmcv.Config(
opset_version=11,
input_shape=None,
input_names=['voxels', 'num_points', 'coors'],
output_names=['scores', 'bbox_preds', 'dir_scores'])))
output_names=['bboxes', 'scores', 'labels'])))
onnx_file = NamedTemporaryFile(suffix='.onnx').name
task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
@ -52,9 +52,9 @@ def backend_model():
wrapper = SwitchBackendWrapper(ORTWrapper)
wrapper.set(
outputs={
'scores': torch.rand(1, 18, 32, 32),
'bbox_preds': torch.rand(1, 42, 32, 32),
'dir_scores': torch.rand(1, 12, 32, 32)
'bboxes': torch.rand(1, 50, 7),
'scores': torch.rand(1, 50),
'labels': torch.rand(1, 50)
})
yield task_processor.init_backend_model([''])

View File

@ -33,15 +33,15 @@ class TestVoxelDetectionModel:
# simplify backend inference
cls.wrapper = SwitchBackendWrapper(ORTWrapper)
cls.outputs = {
'scores': torch.rand(1, 18, 32, 32),
'bbox_preds': torch.rand(1, 42, 32, 32),
'dir_scores': torch.rand(1, 12, 32, 32)
'bboxes': torch.rand(1, 50, 7),
'scores': torch.rand(1, 50),
'labels': torch.rand(1, 50)
}
cls.wrapper.set(outputs=cls.outputs)
deploy_cfg = mmcv.Config({
'onnx_config': {
'input_names': ['voxels', 'num_points', 'coors'],
'output_names': ['scores', 'bbox_preds', 'dir_scores'],
'output_names': ['bboxes', 'scores', 'labels'],
'opset_version': 11
},
'backend_config': {