mirror of https://github.com/alibaba/EasyCV.git
112 lines
4.8 KiB
Python
112 lines
4.8 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
from os import path as osp
|
|
|
|
import mmcv
|
|
import torch
|
|
from mmcv.parallel import DataContainer as DC
|
|
|
|
from easycv.core.bbox.structures import Box3DMode, Coord3DMode
|
|
from easycv.core.visualization.image_3d import show_result
|
|
from easycv.models.base import BaseModel
|
|
|
|
|
|
class Base3DDetector(BaseModel):
    """Base class for detectors.

    Provides test-time dispatch (``forward_test``) and result visualization
    (``show_results``). ``simple_test`` and ``aug_test`` are called here but
    not defined in this class; they are expected to come from subclasses.
    """

    def forward_test(self, points, img_metas, img=None, **kwargs):
        """Dispatch test-time inference to ``simple_test`` or ``aug_test``.

        Args:
            points (list[torch.Tensor]): the outer list indicates test-time
                augmentations and inner torch.Tensor should have a shape NxC,
                which contains all points in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch
            img (list[torch.Tensor], optional): the outer
                list indicates test-time augmentations and inner
                torch.Tensor should have a shape NxCxHxW, which contains
                all images in the batch. Defaults to None.

        Returns:
            The return value of ``self.simple_test`` when exactly one
            augmentation is given, otherwise that of ``self.aug_test``.

        Raises:
            TypeError: If ``points`` or ``img_metas`` is not a list.
            ValueError: If ``points`` and ``img_metas`` differ in length.
        """
        for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(points)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(points), len(img_metas)))

        if num_augs == 1:
            # Wrap a missing image so that ``img[0]`` below yields None.
            img = [img] if img is None else img
            return self.simple_test(points[0], img_metas[0], img[0], **kwargs)
        return self.aug_test(points, img_metas, img, **kwargs)

    def show_results(self, data, result, out_dir, show=False, score_thr=None):
        """Results visualization.

        For every sample in ``result`` the predicted boxes (optionally
        filtered by ``score_thr``) are converted, together with the points,
        into depth mode and handed to ``show_result``.

        Args:
            data (dict): Input points and the information of the sample.
                ``data['points'][0]`` must be a ``DataContainer`` or a list
                of ``torch.Tensor``; ``data['img_metas'][0]`` must be a
                ``DataContainer`` or a list of dicts providing
                ``'pts_filename'`` and ``'box_mode_3d'``.
            result (list[dict]): Prediction results, one dict per sample with
                ``'boxes_3d'`` and ``'labels_3d'`` (and ``'scores_3d'`` when
                ``score_thr`` is given).
            out_dir (str): Output directory of visualization result.
            show (bool, optional): Determines whether you are
                going to show result by open3d.
                Defaults to False.
            score_thr (float, optional): Score threshold of bounding boxes.
                Default to None.

        Raises:
            AssertionError: If ``out_dir`` is None.
            ValueError: If the points / meta container type or the box mode
                is not supported.
        """
        # Loop-invariant argument check: validate once, not per sample.
        assert out_dir is not None, 'Expect out_dir, got none.'

        points_data = data['points'][0]
        metas_data = data['img_metas'][0]

        for batch_id, sample_result in enumerate(result):
            if isinstance(points_data, DC):
                points = points_data._data[0][batch_id].numpy()
            elif mmcv.is_list_of(points_data, torch.Tensor):
                points = points_data[batch_id]
            else:
                # Must raise: otherwise ``points`` is unbound (first sample)
                # or silently stale from the previous sample.
                raise ValueError(f'Unsupported data type {type(points_data)} '
                                 'for visualization!')

            if isinstance(metas_data, DC):
                img_meta = metas_data._data[0][batch_id]
            elif mmcv.is_list_of(metas_data, dict):
                img_meta = metas_data[batch_id]
            else:
                raise ValueError(f'Unsupported data type {type(metas_data)} '
                                 'for visualization!')
            pts_filename = img_meta['pts_filename']
            box_mode_3d = img_meta['box_mode_3d']

            # Base name without directory or extension(s).
            file_name = osp.split(pts_filename)[-1].split('.')[0]

            pred_bboxes = sample_result['boxes_3d']
            pred_labels = sample_result['labels_3d']

            if score_thr is not None:
                mask = sample_result['scores_3d'] > score_thr
                pred_bboxes = pred_bboxes[mask]
                pred_labels = pred_labels[mask]

            # for now we convert points and bbox into depth mode.
            # Points are converted from LIDAR coordinates for both CAM and
            # LIDAR box modes. NOTE(review): this assumes points are always
            # stored in LiDAR coordinates; confirm for CAM datasets.
            if box_mode_3d in (Box3DMode.CAM, Box3DMode.LIDAR):
                points = Coord3DMode.convert_point(points, Coord3DMode.LIDAR,
                                                   Coord3DMode.DEPTH)
                pred_bboxes = Box3DMode.convert(pred_bboxes, box_mode_3d,
                                                Box3DMode.DEPTH)
            elif box_mode_3d != Box3DMode.DEPTH:
                raise ValueError(
                    f'Unsupported box_mode_3d {box_mode_3d} for conversion!')

            pred_bboxes = pred_bboxes.tensor.cpu().numpy()
            show_result(
                points,
                None,
                pred_bboxes,
                out_dir,
                file_name,
                show=show,
                pred_labels=pred_labels)
|