EasyCV/easycv/models/detection3d/detectors/base.py

112 lines
4.8 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Alibaba, Inc. and its affiliates.
from os import path as osp
import mmcv
import torch
from mmcv.parallel import DataContainer as DC
from easycv.core.bbox.structures import Box3DMode, Coord3DMode
from easycv.core.visualization.image_3d import show_result
from easycv.models.base import BaseModel
class Base3DDetector(BaseModel):
    """Base class for 3D detectors.

    Provides test-time dispatch between single-pass inference and test-time
    augmentation (``forward_test``), plus a helper that visualizes / dumps
    predicted 3D boxes (``show_results``). ``simple_test`` and ``aug_test``
    are called here but not defined in this class; subclasses presumably
    provide them.
    """

    def forward_test(self, points, img_metas, img=None, **kwargs):
        """Dispatch test-time inference to ``simple_test`` or ``aug_test``.

        Args:
            points (list[torch.Tensor]): the outer list indicates test-time
                augmentations and inner torch.Tensor should have a shape NxC,
                which contains all points in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.
            img (list[torch.Tensor], optional): the outer
                list indicates test-time augmentations and inner
                torch.Tensor should have a shape NxCxHxW, which contains
                all images in the batch. Defaults to None.

        Returns:
            Whatever ``self.simple_test`` returns when there is exactly one
            augmentation, otherwise whatever ``self.aug_test`` returns.

        Raises:
            TypeError: If ``points`` or ``img_metas`` is not a list.
            ValueError: If ``points`` and ``img_metas`` have different
                lengths.
        """
        for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(points)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(points), len(img_metas)))

        if num_augs == 1:
            # Wrap a missing image so that ``img[0]`` below yields ``None``
            # instead of failing on ``None[0]``.
            img = [img] if img is None else img
            return self.simple_test(points[0], img_metas[0], img[0], **kwargs)
        return self.aug_test(points, img_metas, img, **kwargs)

    def show_results(self, data, result, out_dir, show=False, score_thr=None):
        """Results visualization.

        For each entry of ``result``, fetch the matching point cloud and
        meta information from ``data``, optionally filter boxes by score,
        convert points and boxes to depth mode, and hand them to
        ``show_result``.

        Args:
            data (dict): Input points and the information of the sample.
                ``data['points'][0]`` and ``data['img_metas'][0]`` must be
                either a ``DataContainer`` or a list (of tensors / dicts
                respectively) indexed by batch position.
            result (list[dict]): Prediction results, one dict per sample,
                with keys ``'boxes_3d'``, ``'labels_3d'`` and (when
                ``score_thr`` is given) ``'scores_3d'``.
            out_dir (str): Output directory of visualization result.
            show (bool, optional): Determines whether you are
                going to show result by open3d.
                Defaults to False.
            score_thr (float, optional): Score threshold of bounding boxes.
                Defaults to None (no filtering).

        Raises:
            ValueError: If ``data['points'][0]`` or ``data['img_metas'][0]``
                has an unsupported type, or if ``box_mode_3d`` is not one of
                CAM / LIDAR / DEPTH.
            AssertionError: If ``out_dir`` is None.
        """
        for batch_id in range(len(result)):
            points_field = data['points'][0]
            if isinstance(points_field, DC):
                points = points_field._data[0][batch_id].numpy()
            elif mmcv.is_list_of(points_field, torch.Tensor):
                # NOTE(review): unlike the DC branch this keeps a tensor
                # rather than converting to numpy — confirm show_result
                # accepts both.
                points = points_field[batch_id]
            else:
                # Previously the exception was constructed but never raised,
                # leaving ``points`` unbound or stale from a prior iteration.
                raise ValueError(
                    f"Unsupported data type {type(data['points'][0])} "
                    f'for visualization!')

            metas_field = data['img_metas'][0]
            if isinstance(metas_field, DC):
                pts_filename = metas_field._data[0][batch_id]['pts_filename']
                box_mode_3d = metas_field._data[0][batch_id]['box_mode_3d']
            elif mmcv.is_list_of(metas_field, dict):
                pts_filename = metas_field[batch_id]['pts_filename']
                box_mode_3d = metas_field[batch_id]['box_mode_3d']
            else:
                raise ValueError(
                    f"Unsupported data type {type(data['img_metas'][0])} "
                    f'for visualization!')

            # Base name of the point file without its extension.
            file_name = osp.split(pts_filename)[-1].split('.')[0]

            assert out_dir is not None, 'Expect out_dir, got none.'

            pred_bboxes = result[batch_id]['boxes_3d']
            pred_labels = result[batch_id]['labels_3d']
            if score_thr is not None:
                mask = result[batch_id]['scores_3d'] > score_thr
                pred_bboxes = pred_bboxes[mask]
                pred_labels = pred_labels[mask]

            # for now we convert points and bbox into depth mode
            if box_mode_3d in (Box3DMode.CAM, Box3DMode.LIDAR):
                # NOTE(review): points are converted from LIDAR coordinates
                # even when boxes are in CAM mode; this assumes the raw
                # points are always LiDAR-frame — confirm per dataset.
                points = Coord3DMode.convert_point(points, Coord3DMode.LIDAR,
                                                   Coord3DMode.DEPTH)
                pred_bboxes = Box3DMode.convert(pred_bboxes, box_mode_3d,
                                                Box3DMode.DEPTH)
            elif box_mode_3d != Box3DMode.DEPTH:
                # Previously constructed but not raised, so unconverted boxes
                # were silently visualized in the wrong frame.
                raise ValueError(
                    f'Unsupported box_mode_3d {box_mode_3d} for conversion!')

            pred_bboxes = pred_bboxes.tensor.cpu().numpy()
            show_result(
                points,
                None,
                pred_bboxes,
                out_dir,
                file_name,
                show=show,
                pred_labels=pred_labels)