[Feat] Add end2end pointpillars & centerpoint(pillar) deployment for mmdet3d (#1178)
* add end2end pointpillars & centerpoint(pillar) * fix centerpoint UT * uncomment pycuda * add nvidia copyright and remove post_process for voxel_detection_model * keep comments * add anchor3d_head init * remove pycuda comment * add pcd test sample
pull/1366/head
parent
301035a06f
commit
9bbe3c0355
|
@ -3,4 +3,4 @@ codebase_config = dict(
|
|||
type='mmdet3d', task='VoxelDetection', model_type='end2end')
|
||||
onnx_config = dict(
|
||||
input_names=['voxels', 'num_points', 'coors'],
|
||||
output_names=['scores', 'bbox_preds', 'dir_scores'])
|
||||
output_names=['bboxes', 'scores', 'labels'])
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from . import anchor # noqa: F401,F403
|
||||
from . import bbox # noqa: F401,F403
|
||||
from . import post_processing # noqa: F401,F403
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .anchor_3d_generator import * # noqa: F401,F403
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet3d.core.anchor.anchor_3d_generator.AlignedAnchor3DRangeGenerator.'
|
||||
'anchors_single_range')
|
||||
def alignedanchor3drangegenerator__anchors_single_range(
        ctx,
        self,
        feature_size,
        anchor_range,
        scale,
        sizes=[[3.9, 1.6, 1.56]],
        rotations=[0, 1.5707963],
        device='cuda'):
    """Generate anchors in a single range. Rewrite this func for default
    backend.

    Args:
        ctx: Rewriter context. It is not used by this function.
        self: The generator instance; only ``align_corner`` and
            ``custom_values`` are read from it.
        feature_size (list[float] | tuple[float]): Feature map size. It is
            either a list of a tuple of [D, H, W](in order of z, y, and x).
            A 2-element size is treated as [1, H, W].
        anchor_range (torch.Tensor | list[float]): Range of anchors with
            shape [6]. The order is consistent with that of anchors, i.e.,
            (x_min, y_min, z_min, x_max, y_max, z_max). Integer ranges are
            promoted to the default floating dtype.
        scale (float | int): The scale factor of anchors.
        sizes (list[list] | np.ndarray | torch.Tensor, optional):
            Anchor size with shape [N, 3], in order of x, y, z.
            Defaults to [[3.9, 1.6, 1.56]].
        rotations (list[float] | np.ndarray | torch.Tensor, optional):
            Rotations of anchors in a single feature grid.
            Defaults to [0, 1.5707963].
        device (str, optional): Devices that the anchors will be put on.
            Defaults to 'cuda'.

    Returns:
        torch.Tensor: Anchors with shape
            [*feature_size, num_sizes, num_rots, 7]. When
            ``self.custom_values`` is non-empty, one zero-filled channel per
            custom value is appended to the last dimension.
    """
    # NOTE: the list defaults are only read, never mutated, so sharing them
    # between calls is safe. They are kept to leave the signature unchanged.
    if len(feature_size) == 2:
        feature_size = [1, feature_size[0], feature_size[1]]

    # `as_tensor` avoids the copy-construct warning when a tensor is passed.
    anchor_range = torch.as_tensor(anchor_range, device=device)
    if not anchor_range.is_floating_point():
        # An all-integer range would make the centers integral, and the
        # half-cell shift below cannot be applied to an integer tensor.
        anchor_range = anchor_range.to(torch.get_default_dtype())
    dtype = anchor_range.dtype

    z_centers = torch.arange(feature_size[0], device=device).to(dtype)
    y_centers = torch.arange(feature_size[1], device=device).to(dtype)
    x_centers = torch.arange(feature_size[2], device=device).to(dtype)

    # shift the anchor center to the middle of each cell
    if not self.align_corner:
        z_centers = z_centers + 0.5
        y_centers = y_centers + 0.5
        x_centers = x_centers + 0.5

    # map cell indices onto the metric range of each axis
    z_centers = z_centers / feature_size[0] * (
        anchor_range[5] - anchor_range[2]) + anchor_range[2]
    y_centers = y_centers / feature_size[1] * (
        anchor_range[4] - anchor_range[1]) + anchor_range[1]
    x_centers = x_centers / feature_size[2] * (
        anchor_range[3] - anchor_range[0]) + anchor_range[0]

    sizes = torch.as_tensor(sizes, device=device).reshape(-1, 3) * scale
    rotations = torch.as_tensor(rotations, device=device)

    # torch.meshgrid default behavior is 'ij', np's default is 'xy'
    rets = torch.meshgrid(x_centers, y_centers, z_centers, rotations)

    # torch.meshgrid returns a tuple rather than list
    rets = list(rets)

    # Insert a "size" axis before the rotation axis and tile every grid
    # along it, then add a trailing channel axis for concatenation.
    tile_shape = [1] * 5
    tile_shape[-2] = int(sizes.shape[0])
    for i in range(len(rets)):
        rets[i] = rets[i].unsqueeze(-2).repeat(tile_shape).unsqueeze(-1)

    # Broadcast the sizes over the whole (x, y, z, rot) grid.
    sizes = sizes.reshape([1, 1, 1, -1, 1, 3])
    tile_size_shape = list(rets[0].shape)
    tile_size_shape[3] = 1
    sizes = sizes.repeat(tile_size_shape)
    rets.insert(3, sizes)

    # channels: x, y, z, size_x, size_y, size_z, rot; axes go (x, y, z) ->
    # (z, y, x)
    ret = torch.cat(rets, dim=-1).permute([2, 1, 0, 3, 4, 5])

    if len(self.custom_values) > 0:
        custom_ndim = len(self.custom_values)
        custom = ret.new_zeros([*ret.shape[:-1], custom_ndim])
        # TODO: check the support of custom values
        # custom[:] = self.custom_values
        ret = torch.cat([ret, custom], dim=-1)
    return ret
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from . import centerpoint_bbox_coders # noqa: F401,F403
|
||||
from . import fcos3d_bbox_coder # noqa: F401,F403
|
||||
from .utils import points_img2cam
|
||||
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet3d.core.bbox.coders.centerpoint_bbox_coders.CenterPointBBoxCoder.'
|
||||
'decode')
|
||||
def centerpointbboxcoder__decode(ctx,
                                 self,
                                 heat,
                                 rot_sine,
                                 rot_cosine,
                                 hei,
                                 dim,
                                 vel,
                                 reg=None,
                                 task_id=-1):
    """Decode bboxes. Rewrite this func for default backend.

    Out-of-range predictions are zeroed instead of being removed, so the
    outputs keep a fixed length of ``self.max_num``.

    Args:
        ctx: Rewriter context. It is not used by this function.
        self: The bbox coder instance. ``max_num``, ``out_size_factor``,
            ``voxel_size``, ``pc_range``, ``post_center_range``, ``_topk``
            and ``_transpose_and_gather_feat`` are read from it.
        heat (torch.Tensor): Heatmap with the shape of [B, N, W, H].
        rot_sine (torch.Tensor): Sine of rotation with the shape of
            [B, 1, W, H].
        rot_cosine (torch.Tensor): Cosine of rotation with the shape of
            [B, 1, W, H].
        hei (torch.Tensor): Height of the boxes with the shape
            of [B, 1, W, H].
        dim (torch.Tensor): Dim of the boxes. The gathered values are viewed
            with 3 channels, so presumably [B, 3, W, H] - TODO confirm.
        vel (torch.Tensor | None): Velocity. The gathered values are viewed
            with 2 channels, so presumably [B, 2, W, H] - TODO confirm.
            ``None`` selects the 7-value box layout without velocity.
        reg (torch.Tensor, optional): Regression value of the boxes in
            2D with the shape of [B, 2, W, H]. Default: None.
        task_id (int, optional): Index of task. Default: -1. It is not used
            by this function.

    Returns:
        list[dict]: A single-element list. The dict holds ``bboxes``,
            ``scores`` and ``labels`` of the first sample in the batch only.

    Raises:
        NotImplementedError: If ``self.post_center_range`` is None.
    """
    if self.post_center_range is None:
        raise NotImplementedError(
            'post_center_range is required to decode boxes for deployment')

    batch = heat.size()[0]
    max_num = self.max_num

    scores, inds, clses, ys, xs = self._topk(heat, K=max_num)

    if reg is not None:
        reg = self._transpose_and_gather_feat(reg, inds)
        reg = reg.view(batch, max_num, 2)
        xs = xs.view(batch, max_num, 1) + reg[:, :, 0:1]
        ys = ys.view(batch, max_num, 1) + reg[:, :, 1:2]
    else:
        # without a regressed offset, use the cell center
        xs = xs.view(batch, max_num, 1) + 0.5
        ys = ys.view(batch, max_num, 1) + 0.5

    # rotation value and direction label
    rot_sine = self._transpose_and_gather_feat(rot_sine, inds)
    rot_sine = rot_sine.view(batch, max_num, 1)

    rot_cosine = self._transpose_and_gather_feat(rot_cosine, inds)
    rot_cosine = rot_cosine.view(batch, max_num, 1)
    rot = torch.atan2(rot_sine, rot_cosine)

    # height in the bev
    hei = self._transpose_and_gather_feat(hei, inds)
    hei = hei.view(batch, max_num, 1)

    # dim of the box
    dim = self._transpose_and_gather_feat(dim, inds)
    dim = dim.view(batch, max_num, 3)

    # class label
    clses = clses.view(batch, max_num).float()
    scores = scores.view(batch, max_num)

    # feature-map coordinates -> point-cloud coordinates
    xs = xs.view(
        batch, max_num,
        1) * self.out_size_factor * self.voxel_size[0] + self.pc_range[0]
    ys = ys.view(
        batch, max_num,
        1) * self.out_size_factor * self.voxel_size[1] + self.pc_range[1]

    if vel is None:  # KITTI FORMAT
        final_box_preds = torch.cat([xs, ys, hei, dim, rot], dim=2)
    else:  # exist velocity, nuscene format
        vel = self._transpose_and_gather_feat(vel, inds)
        vel = vel.view(batch, max_num, 2)
        final_box_preds = torch.cat([xs, ys, hei, dim, rot, vel], dim=2)

    final_scores = scores
    final_preds = clses

    # Keep the range in a local tensor: overwriting the attribute on `self`
    # would be a hidden side effect and would re-wrap a tensor on later calls.
    post_center_range = torch.as_tensor(
        self.post_center_range, device=heat.device)

    # A box is kept only if its center lies inside the range on all three
    # axes; the product of the six comparisons acts as a logical AND.
    range_mask = torch.prod(
        torch.cat((final_box_preds[..., :3] >= post_center_range[:3],
                   final_box_preds[..., :3] <= post_center_range[3:]),
                  dim=-1),
        dim=-1).bool()

    # zero out rejected entries rather than dropping them
    zero = torch.zeros(1, device=heat.device)
    final_box_preds = torch.where(
        range_mask.unsqueeze(-1), final_box_preds, zero)
    final_scores = torch.where(range_mask, final_scores, zero)
    final_preds = torch.where(range_mask, final_preds, zero)

    predictions_dict = {
        'bboxes': final_box_preds[0],
        'scores': final_scores[0],
        'labels': final_preds[0],
    }
    return [predictions_dict]
|
|
@ -163,8 +163,8 @@ def _box3d_multiclass_nms(
|
|||
dir_scores, attr_scores)
|
||||
|
||||
|
||||
# @FUNCTION_REWRITER.register_rewriter(
|
||||
# func_name='mmdet3d.core.post_processing.box3d_multiclass_nms')
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet3d.core.post_processing.box3d_multiclass_nms')
|
||||
def box3d_multiclass_nms(*args, **kwargs):
|
||||
"""Wrapper function for `_box3d_multiclass_nms`."""
|
||||
return mmdeploy.codebase.mmdet3d.core.post_processing.box3d_nms.\
|
||||
|
|
|
@ -95,6 +95,8 @@ class VoxelDetection(BaseTask):
|
|||
score_thr (float): The score threshold to display the bbox.
|
||||
Defaults to 0.3.
|
||||
"""
|
||||
if output_file.endswith('.jpg'):
|
||||
output_file = output_file.split('.')[0]
|
||||
from mmdet3d.apis import show_result_meshlab
|
||||
data = VoxelDetection.read_pcd_file(image, self.model_cfg, self.device)
|
||||
show_result_meshlab(
|
||||
|
|
|
@ -7,9 +7,8 @@ from mmcv.utils import Registry
|
|||
from torch.nn import functional as F
|
||||
|
||||
from mmdeploy.codebase.base import BaseBackendModel
|
||||
from mmdeploy.core import RewriterContext
|
||||
from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
|
||||
get_root_logger, load_config)
|
||||
load_config)
|
||||
|
||||
|
||||
def __build_backend_voxel_model(cls_name: str, registry: Registry, *args,
|
||||
|
@ -83,7 +82,7 @@ class VoxelDetectionModel(BaseBackendModel):
|
|||
Returns:
|
||||
list: A list contains predictions.
|
||||
"""
|
||||
result_list = []
|
||||
results = []
|
||||
for i in range(len(img_metas)):
|
||||
voxels, num_points, coors = VoxelDetectionModel.voxelize(
|
||||
self.model_cfg, points[i])
|
||||
|
@ -93,12 +92,15 @@ class VoxelDetectionModel(BaseBackendModel):
|
|||
'coors': coors
|
||||
}
|
||||
outputs = self.wrapper(input_dict)
|
||||
result = VoxelDetectionModel.post_process(self.model_cfg,
|
||||
self.deploy_cfg, outputs,
|
||||
img_metas[i],
|
||||
self.device)[0]
|
||||
result_list.append(result)
|
||||
return result_list
|
||||
outputs = self.wrapper.output_to_list(outputs)
|
||||
outputs = [x.squeeze(0) for x in outputs]
|
||||
bbox_dim = outputs[0].shape[-1]
|
||||
outputs[0] = img_metas[0][0]['box_type_3d'](outputs[0], bbox_dim)
|
||||
from mmdet3d.core import bbox3d2result
|
||||
|
||||
result = bbox3d2result(*outputs)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
def show_result(self,
|
||||
data: Dict,
|
||||
|
@ -171,66 +173,6 @@ class VoxelDetectionModel(BaseBackendModel):
|
|||
coors_batch = torch.cat(coors_batch, dim=0)
|
||||
return voxels, num_points, coors_batch
|
||||
|
||||
@staticmethod
|
||||
def post_process(model_cfg: Union[str, mmcv.Config],
|
||||
deploy_cfg: Union[str, mmcv.Config],
|
||||
outs: Dict,
|
||||
img_metas: Dict,
|
||||
device: str,
|
||||
rescale=False):
|
||||
"""model post process.
|
||||
|
||||
Args:
|
||||
model_cfg (str | mmcv.Config): The model config.
|
||||
deploy_cfg (str|mmcv.Config): Deployment config file or loaded
|
||||
Config object.
|
||||
outs (Dict): Output of model's head.
|
||||
img_metas(Dict): Meta info for pcd.
|
||||
device (str): A string specifying device type.
|
||||
rescale (list[torch.Tensor]): whether th rescale bbox.
|
||||
Returns:
|
||||
list: A list contains predictions, include bboxes, scores, labels.
|
||||
"""
|
||||
from mmdet3d.core import bbox3d2result
|
||||
from mmdet3d.models.builder import build_head
|
||||
model_cfg = load_config(model_cfg)[0]
|
||||
deploy_cfg = load_config(deploy_cfg)[0]
|
||||
if 'bbox_head' in model_cfg.model.keys():
|
||||
head_cfg = dict(**model_cfg.model['bbox_head'])
|
||||
elif 'pts_bbox_head' in model_cfg.model.keys():
|
||||
head_cfg = dict(**model_cfg.model['pts_bbox_head'])
|
||||
else:
|
||||
raise NotImplementedError('Not supported model.')
|
||||
head_cfg['train_cfg'] = None
|
||||
head_cfg['test_cfg'] = model_cfg.model['test_cfg']\
|
||||
if 'pts' not in model_cfg.model['test_cfg'].keys()\
|
||||
else model_cfg.model['test_cfg']['pts']
|
||||
head = build_head(head_cfg)
|
||||
if device == 'cpu':
|
||||
logger = get_root_logger()
|
||||
logger.warning(
|
||||
'Don\'t suggest using CPU device. Post process can\'t support.'
|
||||
)
|
||||
if torch.cuda.is_available():
|
||||
device = 'cuda'
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'Post process don\'t support device=cpu')
|
||||
cls_scores = [outs['scores'].to(device)]
|
||||
bbox_preds = [outs['bbox_preds'].to(device)]
|
||||
dir_scores = [outs['dir_scores'].to(device)]
|
||||
with RewriterContext(
|
||||
cfg=deploy_cfg,
|
||||
backend=deploy_cfg.backend_config.type,
|
||||
opset=deploy_cfg.onnx_config.opset_version):
|
||||
bbox_list = head.get_bboxes(
|
||||
cls_scores, bbox_preds, dir_scores, img_metas, rescale=False)
|
||||
bbox_results = [
|
||||
bbox3d2result(bboxes, scores, labels)
|
||||
for bboxes, scores, labels in bbox_list
|
||||
]
|
||||
return bbox_results
|
||||
|
||||
|
||||
def build_voxel_detection_model(model_files: Sequence[str],
|
||||
model_cfg: Union[str, mmcv.Config],
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from . import anchor3d_head # noqa: F401,F403
|
||||
from . import anchor_free_mono3d_head # noqa: F401,F403
|
||||
from . import base # noqa: F401,F403
|
||||
from . import centerpoint # noqa: F401,F403
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmdet3d.core.bbox.structures import limit_period
|
||||
|
||||
from mmdeploy.codebase.mmdet3d.core.post_processing import box3d_multiclass_nms
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet3d.models.dense_heads.anchor3d_head.Anchor3DHead.'
|
||||
'get_bboxes')
|
||||
def anchor3dhead__get_bboxes(ctx,
                             self,
                             cls_scores,
                             bbox_preds,
                             dir_cls_preds,
                             input_metas,
                             cfg=None,
                             rescale=False):
    """Rewrite `get_bboxes` of `Anchor3DHead` for default backend.

    Every level is flattened with a leading dimension of 1, so this rewrite
    handles a single sample per call.

    Args:
        ctx (ContextCaller): The context with additional information.
            It is not used by this function.
        self: The head instance this rewrite is bound to.
        cls_scores (list[Tensor]): Box scores for each scale level
            with shape (N, num_anchors * num_classes, H, W).
        bbox_preds (list[Tensor]): Box energies / deltas for each scale
            level with shape (N, num_anchors * 7, H, W).
        dir_cls_preds (list[Tensor]): Direction predicts for
            all scale level, each is a 4D-tensor. They are reshaped into
            pairs of direction logits.
        input_metas (list[dict]): Meta information of the input. It is not
            used by this function.
        cfg (mmcv.Config | None): Test / postprocessing configuration,
            if None, test_cfg would be used. Default: None.
        rescale (bool): Not used by this function. Default: False.

    Returns:
        The value returned by `box3d_multiclass_nms`; the original
        docstring describes it as (bboxes, scores, labels) with `bboxes` of
        shape [N, num_det, 7], `scores` of shape [N, num_det] and `labels`
        of shape [N, num_det].
    """
    assert len(cls_scores) == len(bbox_preds) == len(dir_cls_preds)
    device = cls_scores[0].device
    featmap_sizes = [level.shape[-2:] for level in cls_scores]
    level_anchors = [
        grid.reshape(-1, self.box_code_size)
        for grid in self.anchor_generator.grid_anchors(
            featmap_sizes, device=device)
    ]

    cls_scores = [level.detach() for level in cls_scores]
    bbox_preds = [level.detach() for level in bbox_preds]
    dir_cls_preds = [level.detach() for level in dir_cls_preds]

    if cfg is None:
        cfg = self.test_cfg
    nms_pre = cfg.get('nms_pre', -1)

    per_level_boxes, per_level_scores, per_level_dirs = [], [], []
    for cls_score, bbox_pred, dir_cls_pred, anchors in zip(
            cls_scores, bbox_preds, dir_cls_preds, level_anchors):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]

        # channels-last, then flatten all locations and anchors
        dir_logits = dir_cls_pred.permute(0, 2, 3, 1).reshape(1, -1, 2)
        dir_index = torch.max(dir_logits, dim=-1)[1]

        cls_logits = cls_score.permute(0, 2, 3, 1).reshape(
            1, -1, self.num_classes)
        scores = (
            cls_logits.sigmoid()
            if self.use_sigmoid_cls else cls_logits.softmax(-1))
        deltas = bbox_pred.permute(0, 2, 3, 1).reshape(
            1, -1, self.box_code_size)

        if nms_pre > 0 and scores.shape[1] > nms_pre:
            # keep only the `nms_pre` best candidates of this level; the
            # last softmax channel is left out of the ranking
            ranking_scores = scores if self.use_sigmoid_cls else scores[
                ..., :-1]
            best_scores, _ = ranking_scores.max(dim=2)
            _, topk_inds = best_scores[0].topk(nms_pre)
            anchors = anchors[topk_inds, :]
            deltas = deltas[:, topk_inds, :]
            scores = scores[:, topk_inds, :]
            dir_index = dir_index[:, topk_inds]

        per_level_boxes.append(self.bbox_coder.decode(anchors, deltas))
        per_level_scores.append(scores)
        per_level_dirs.append(dir_index)

    mlvl_bboxes = torch.cat(per_level_boxes, dim=1)
    # 2D (x, y, w, l, yaw) view of the boxes, taken before the yaw is
    # adjusted below
    mlvl_bboxes_for_nms = mlvl_bboxes[..., [0, 1, 3, 4, 6]].clone()
    mlvl_scores = torch.cat(per_level_scores, dim=1)
    mlvl_dir_scores = torch.cat(per_level_dirs, dim=1)

    # NOTE(review): shape[0] looks like the leading (batch) dimension here
    # rather than the number of boxes - confirm the intent of this guard.
    if mlvl_bboxes.shape[0] > 0:
        dir_rot = limit_period(mlvl_bboxes[..., 6] - self.dir_offset,
                               self.dir_limit_offset, np.pi)
        # fold the predicted direction bin back into the yaw angle
        mlvl_bboxes[..., 6] = (
            dir_rot + self.dir_offset +
            np.pi * mlvl_dir_scores.to(mlvl_bboxes.dtype))

    return box3d_multiclass_nms(mlvl_bboxes, mlvl_bboxes_for_nms, mlvl_scores,
                                cfg.score_thr, cfg.nms_thr, cfg.max_num)
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmdet3d.core import circle_nms
|
||||
|
||||
from mmdeploy.codebase.mmdet3d.core.post_processing import box3d_multiclass_nms
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
|
@ -48,37 +48,22 @@ def centerpoint__simple_test_pts(ctx, self, x, img_metas, rescale=False):
|
|||
List: Result of model.
|
||||
"""
|
||||
outs = self.pts_bbox_head(x)
|
||||
bbox_preds, scores, dir_scores = [], [], []
|
||||
for task_res in outs:
|
||||
bbox_preds.append(task_res[0]['reg'])
|
||||
bbox_preds.append(task_res[0]['height'])
|
||||
bbox_preds.append(task_res[0]['dim'])
|
||||
if 'vel' in task_res[0].keys():
|
||||
bbox_preds.append(task_res[0]['vel'])
|
||||
scores.append(task_res[0]['heatmap'])
|
||||
dir_scores.append(task_res[0]['rot'])
|
||||
bbox_preds = torch.cat(bbox_preds, dim=1)
|
||||
scores = torch.cat(scores, dim=1)
|
||||
dir_scores = torch.cat(dir_scores, dim=1)
|
||||
return scores, bbox_preds, dir_scores
|
||||
bbox_list = self.pts_bbox_head.get_bboxes(outs, img_metas, rescale=rescale)
|
||||
return bbox_list
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet3d.models.dense_heads.centerpoint_head.CenterHead.get_bboxes')
|
||||
def centerpoint__get_bbox(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
dir_scores,
|
||||
preds_dicts,
|
||||
img_metas,
|
||||
img=None,
|
||||
rescale=False):
|
||||
"""Rewrite this func to format func inputs.
|
||||
|
||||
Args
|
||||
cls_scores (list[torch.Tensor]): Classification predicts results.
|
||||
bbox_preds (list[torch.Tensor]): Bbox predicts results.
|
||||
dir_scores (list[torch.Tensor]): Dir predicts results.
|
||||
pred_dicts (list[dict]): Each task predicts results.
|
||||
img_metas (list[dict]): Point cloud and image's meta info.
|
||||
img (torch.Tensor): Input image.
|
||||
rescale (Bool): Whether need rescale.
|
||||
|
@ -87,44 +72,24 @@ def centerpoint__get_bbox(ctx,
|
|||
list[dict]: Decoded bbox, scores and labels after nms.
|
||||
"""
|
||||
rets = []
|
||||
scores_range = [0]
|
||||
bbox_range = [0]
|
||||
dir_range = [0]
|
||||
for i, task_head in enumerate(self.task_heads):
|
||||
scores_range.append(scores_range[i] + self.num_classes[i])
|
||||
bbox_range.append(bbox_range[i] + 8)
|
||||
dir_range.append(dir_range[i] + 2)
|
||||
for task_id in range(len(self.num_classes)):
|
||||
num_class_with_bg = self.num_classes[task_id]
|
||||
for task_id, preds_dict in enumerate(preds_dicts):
|
||||
batch_heatmap = preds_dict[0]['heatmap'].sigmoid()
|
||||
|
||||
batch_heatmap = cls_scores[
|
||||
0][:, scores_range[task_id]:scores_range[task_id + 1],
|
||||
...].sigmoid()
|
||||
|
||||
batch_reg = bbox_preds[0][:,
|
||||
bbox_range[task_id]:bbox_range[task_id] + 2,
|
||||
...]
|
||||
batch_hei = bbox_preds[0][:, bbox_range[task_id] +
|
||||
2:bbox_range[task_id] + 3, ...]
|
||||
batch_reg = preds_dict[0]['reg']
|
||||
batch_hei = preds_dict[0]['height']
|
||||
|
||||
if self.norm_bbox:
|
||||
batch_dim = torch.exp(bbox_preds[0][:, bbox_range[task_id] +
|
||||
3:bbox_range[task_id] + 6,
|
||||
...])
|
||||
batch_dim = torch.exp(preds_dict[0]['dim'])
|
||||
else:
|
||||
batch_dim = bbox_preds[0][:, bbox_range[task_id] +
|
||||
3:bbox_range[task_id] + 6, ...]
|
||||
batch_dim = preds_dict[0]['dim']
|
||||
|
||||
batch_vel = bbox_preds[0][:, bbox_range[task_id] +
|
||||
6:bbox_range[task_id + 1], ...]
|
||||
|
||||
batch_rots = dir_scores[0][:,
|
||||
dir_range[task_id]:dir_range[task_id + 1],
|
||||
...][:, 0].unsqueeze(1)
|
||||
batch_rotc = dir_scores[0][:,
|
||||
dir_range[task_id]:dir_range[task_id + 1],
|
||||
...][:, 1].unsqueeze(1)
|
||||
batch_rots = preds_dict[0]['rot'][:, 0].unsqueeze(1)
|
||||
batch_rotc = preds_dict[0]['rot'][:, 1].unsqueeze(1)
|
||||
|
||||
if 'vel' in preds_dict[0]:
|
||||
batch_vel = preds_dict[0]['vel']
|
||||
else:
|
||||
batch_vel = None
|
||||
temp = self.bbox_coder.decode(
|
||||
batch_heatmap,
|
||||
batch_rots,
|
||||
|
@ -134,57 +99,32 @@ def centerpoint__get_bbox(ctx,
|
|||
batch_vel,
|
||||
reg=batch_reg,
|
||||
task_id=task_id)
|
||||
if 'pts' in self.test_cfg.keys():
|
||||
self.test_cfg = self.test_cfg.pts
|
||||
assert self.test_cfg['nms_type'] in ['circle', 'rotate']
|
||||
batch_reg_preds = [box['bboxes'] for box in temp]
|
||||
batch_cls_preds = [box['scores'] for box in temp]
|
||||
batch_cls_labels = [box['labels'] for box in temp]
|
||||
batch_bboxes = temp[0]['bboxes'].unsqueeze(0)
|
||||
batch_scores = temp[0]['scores'].unsqueeze(0).unsqueeze(-1)
|
||||
batch_cls_labels = temp[0]['labels'].unsqueeze(0).long()
|
||||
batch_bboxes_for_nms = batch_bboxes[..., [0, 1, 3, 4, 6]].clone()
|
||||
if self.test_cfg['nms_type'] == 'circle':
|
||||
|
||||
boxes3d = temp[0]['bboxes']
|
||||
scores = temp[0]['scores']
|
||||
labels = temp[0]['labels']
|
||||
centers = boxes3d[:, [0, 1]]
|
||||
boxes = torch.cat([centers, scores.view(-1, 1)], dim=1)
|
||||
keep = torch.tensor(
|
||||
circle_nms(
|
||||
boxes.detach().cpu().numpy(),
|
||||
self.test_cfg['min_radius'][task_id],
|
||||
post_max_size=self.test_cfg['post_max_size']),
|
||||
dtype=torch.long,
|
||||
device=boxes.device)
|
||||
|
||||
boxes3d = boxes3d[keep]
|
||||
scores = scores[keep]
|
||||
labels = labels[keep]
|
||||
ret = dict(bboxes=boxes3d, scores=scores, labels=labels)
|
||||
ret_task = [ret]
|
||||
rets.append(ret_task)
|
||||
raise NotImplementedError(
|
||||
'Not implement circle nms for deployment now!')
|
||||
else:
|
||||
rets.append(
|
||||
self.get_task_detections(num_class_with_bg, batch_cls_preds,
|
||||
batch_reg_preds, batch_cls_labels,
|
||||
img_metas))
|
||||
box3d_multiclass_nms(batch_bboxes, batch_bboxes_for_nms,
|
||||
batch_scores,
|
||||
self.test_cfg['score_threshold'],
|
||||
self.test_cfg['nms_thr'],
|
||||
self.test_cfg['post_max_size'], None,
|
||||
batch_cls_labels))
|
||||
|
||||
# Merge branches results
|
||||
num_samples = len(rets[0])
|
||||
bboxes = torch.cat([ret[0] for ret in rets], dim=1)
|
||||
bboxes[..., 2] = bboxes[..., 2] - bboxes[..., 5] * 0.5
|
||||
scores = torch.cat([ret[1] for ret in rets], dim=1)
|
||||
|
||||
ret_list = []
|
||||
for i in range(num_samples):
|
||||
for k in rets[0][i].keys():
|
||||
if k == 'bboxes':
|
||||
bboxes = torch.cat([ret[i][k] for ret in rets])
|
||||
bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
|
||||
bboxes = img_metas[i]['box_type_3d'](bboxes,
|
||||
self.bbox_coder.code_size)
|
||||
elif k == 'scores':
|
||||
scores = torch.cat([ret[i][k] for ret in rets])
|
||||
elif k == 'labels':
|
||||
flag = 0
|
||||
for j, num_class in enumerate(self.num_classes):
|
||||
rets[j][i][k] += flag
|
||||
flag += num_class
|
||||
labels = torch.cat([ret[i][k].int() for ret in rets])
|
||||
ret_list.append([bboxes, scores, labels])
|
||||
return ret_list
|
||||
labels = [ret[3] for ret in rets]
|
||||
flag = 0
|
||||
for i, num_class in enumerate(self.num_classes):
|
||||
labels[i] += flag
|
||||
flag += num_class
|
||||
labels = torch.cat(labels, dim=1)
|
||||
return bboxes, scores, labels
|
||||
|
|
|
@ -102,5 +102,6 @@ def mvxtwostagedetector__simple_test_pts(ctx,
|
|||
Returns:
|
||||
List: Result of model.
|
||||
"""
|
||||
bbox_preds, scores, dir_scores = self.pts_bbox_head(x)
|
||||
return bbox_preds, scores, dir_scores
|
||||
outs = self.pts_bbox_head(x)
|
||||
outs = self.pts_bbox_head.get_bboxes(*outs, img_metas, rescale=rescale)
|
||||
return outs
|
||||
|
|
|
@ -30,8 +30,7 @@ def pointpillarsscatter__forward(ctx,
|
|||
indices = indices.long()
|
||||
voxels = voxel_features.t()
|
||||
# Now scatter the blob back to the canvas.
|
||||
canvas.scatter_(
|
||||
dim=1, index=indices.expand(canvas.shape[0], -1), src=voxels)
|
||||
canvas[:, indices] = voxels
|
||||
# Undo the column stacking to final 4-dim tensor
|
||||
canvas = canvas.view(1, self.in_channels, self.ny, self.nx)
|
||||
return canvas
|
||||
|
|
|
@ -25,8 +25,9 @@ def voxelnet__simple_test(ctx,
|
|||
List: Result of model.
|
||||
"""
|
||||
x = self.extract_feat(voxels, num_points, coors, img_metas)
|
||||
bbox_preds, scores, dir_scores = self.bbox_head(x)
|
||||
return bbox_preds, scores, dir_scores
|
||||
outs = self.bbox_head(x)
|
||||
outs = self.bbox_head.get_bboxes(*outs, img_metas, rescale=rescale)
|
||||
return outs
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
|
|
|
@ -170,6 +170,7 @@ class TRTBatchedBEVNMSop(torch.autograd.Function):
|
|||
background_label_id: int = -1,
|
||||
return_index: bool = True):
|
||||
"""Forward of batched nms.
|
||||
|
||||
Args:
|
||||
ctx (Context): The context with meta information.
|
||||
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
|
||||
|
|
|
@ -224,7 +224,7 @@ centerpoint_model = dict(
|
|||
score_threshold=0.1,
|
||||
out_size_factor=4,
|
||||
voxel_size=voxel_size[:2],
|
||||
nms_type='circle',
|
||||
nms_type='rotate',
|
||||
pre_max_size=1000,
|
||||
post_max_size=83,
|
||||
nms_thr=0.2)))
|
||||
|
|
Binary file not shown.
|
@ -155,10 +155,7 @@ def test_centerpoint(backend_type: Backend):
|
|||
cfg=deploy_cfg,
|
||||
backend=deploy_cfg.backend_config.type,
|
||||
opset=deploy_cfg.onnx_config.opset_version):
|
||||
outputs = model.forward(*data)
|
||||
head = get_centerpoint_head()
|
||||
rewrite_outputs = head.get_bboxes(*[[i] for i in outputs],
|
||||
inputs['img_metas'][0])
|
||||
rewrite_outputs = model.forward(*data)
|
||||
assert rewrite_outputs is not None
|
||||
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ deploy_cfg = mmcv.Config(
|
|||
opset_version=11,
|
||||
input_shape=None,
|
||||
input_names=['voxels', 'num_points', 'coors'],
|
||||
output_names=['scores', 'bbox_preds', 'dir_scores'])))
|
||||
output_names=['bboxes', 'scores', 'labels'])))
|
||||
onnx_file = NamedTemporaryFile(suffix='.onnx').name
|
||||
task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
|
||||
|
||||
|
@ -52,9 +52,9 @@ def backend_model():
|
|||
wrapper = SwitchBackendWrapper(ORTWrapper)
|
||||
wrapper.set(
|
||||
outputs={
|
||||
'scores': torch.rand(1, 18, 32, 32),
|
||||
'bbox_preds': torch.rand(1, 42, 32, 32),
|
||||
'dir_scores': torch.rand(1, 12, 32, 32)
|
||||
'bboxes': torch.rand(1, 50, 7),
|
||||
'scores': torch.rand(1, 50),
|
||||
'labels': torch.rand(1, 50)
|
||||
})
|
||||
|
||||
yield task_processor.init_backend_model([''])
|
||||
|
|
|
@ -33,15 +33,15 @@ class TestVoxelDetectionModel:
|
|||
# simplify backend inference
|
||||
cls.wrapper = SwitchBackendWrapper(ORTWrapper)
|
||||
cls.outputs = {
|
||||
'scores': torch.rand(1, 18, 32, 32),
|
||||
'bbox_preds': torch.rand(1, 42, 32, 32),
|
||||
'dir_scores': torch.rand(1, 12, 32, 32)
|
||||
'bboxes': torch.rand(1, 50, 7),
|
||||
'scores': torch.rand(1, 50),
|
||||
'labels': torch.rand(1, 50)
|
||||
}
|
||||
cls.wrapper.set(outputs=cls.outputs)
|
||||
deploy_cfg = mmcv.Config({
|
||||
'onnx_config': {
|
||||
'input_names': ['voxels', 'num_points', 'coors'],
|
||||
'output_names': ['scores', 'bbox_preds', 'dir_scores'],
|
||||
'output_names': ['bboxes', 'scores', 'labels'],
|
||||
'opset_version': 11
|
||||
},
|
||||
'backend_config': {
|
||||
|
|
Loading…
Reference in New Issue