# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Alibaba, Inc. and its affiliates.
from os import path as osp

import mmcv
import torch
from mmcv.parallel import DataContainer as DC

from easycv.core.bbox.structures import Box3DMode, Coord3DMode
from easycv.core.visualization.image_3d import show_result
from easycv.models.base import BaseModel


class Base3DDetector(BaseModel):
    """Base class for detectors."""

    def forward_test(self, points, img_metas, img=None, **kwargs):
        """Dispatch a test-time forward pass to single- or multi-aug testing.

        Args:
            points (list[torch.Tensor]): the outer list indicates test-time
                augmentations and inner torch.Tensor should have a shape NxC,
                which contains all points in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch
            img (list[torch.Tensor], optional): the outer list indicates
                test-time augmentations and inner torch.Tensor should have a
                shape NxCxHxW, which contains all images in the batch.
                Defaults to None.
            **kwargs: Forwarded unchanged to ``simple_test`` / ``aug_test``.

        Returns:
            Whatever ``self.simple_test`` (exactly one augmentation) or
            ``self.aug_test`` (any other count) returns. Neither method is
            defined in this class; presumably subclasses provide them.

        Raises:
            TypeError: If ``points`` or ``img_metas`` is not a list.
            ValueError: If ``points`` and ``img_metas`` differ in length.
        """
        for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError(f'{name} must be a list, but got {type(var)}')

        num_augs = len(points)
        if num_augs != len(img_metas):
            raise ValueError(
                f'num of augmentations ({len(points)}) '
                f'!= num of image meta ({len(img_metas)})')

        if num_augs == 1:
            # Wrap a missing image so that ``img[0]`` below yields ``None``.
            img = [img] if img is None else img
            return self.simple_test(points[0], img_metas[0], img[0], **kwargs)
        return self.aug_test(points, img_metas, img, **kwargs)

    def show_results(self, data, result, out_dir, show=False, score_thr=None):
        """Results visualization.

        Args:
            data (list[dict]): Input points and the information of the sample.
            result (list[dict]): Prediction results.
            out_dir (str): Output directory of visualization result.
            show (bool, optional): Determines whether you are going to show
                result by open3d. Defaults to False.
            score_thr (float, optional): Score threshold of bounding boxes.
                Defaults to None.

        Raises:
            ValueError: If ``data['points'][0]`` or ``data['img_metas'][0]``
                has an unsupported container type, or if the sample's
                ``box_mode_3d`` is not one of ``Box3DMode.CAM``,
                ``Box3DMode.LIDAR`` or ``Box3DMode.DEPTH``.
            AssertionError: If ``out_dir`` is None.
        """
        for batch_id, sample_result in enumerate(result):
            points_data = data['points'][0]
            if isinstance(points_data, DC):
                points = points_data._data[0][batch_id].numpy()
            elif mmcv.is_list_of(points_data, torch.Tensor):
                # Tensors in a plain list are passed on as tensors.
                points = points_data[batch_id]
            else:
                raise ValueError(
                    f"Unsupported data type {type(data['points'][0])} "
                    f'for visualization!')

            metas_data = data['img_metas'][0]
            if isinstance(metas_data, DC):
                sample_meta = metas_data._data[0][batch_id]
            elif mmcv.is_list_of(metas_data, dict):
                sample_meta = metas_data[batch_id]
            else:
                raise ValueError(
                    f"Unsupported data type {type(data['img_metas'][0])} "
                    f'for visualization!')
            pts_filename = sample_meta['pts_filename']
            box_mode_3d = sample_meta['box_mode_3d']

            # Output name: base name of the point file up to its first dot.
            file_name = osp.split(pts_filename)[-1].split('.')[0]

            assert out_dir is not None, 'Expect out_dir, got none.'

            pred_bboxes = sample_result['boxes_3d']
            pred_labels = sample_result['labels_3d']

            if score_thr is not None:
                mask = sample_result['scores_3d'] > score_thr
                pred_bboxes = pred_bboxes[mask]
                pred_labels = pred_labels[mask]

            # for now we convert points and bbox into depth mode
            if box_mode_3d in (Box3DMode.CAM, Box3DMode.LIDAR):
                # NOTE(review): points are always converted from LIDAR
                # coordinates here, even when the boxes are in CAM mode.
                # Confirm this is intended.
                points = Coord3DMode.convert_point(points, Coord3DMode.LIDAR,
                                                   Coord3DMode.DEPTH)
                pred_bboxes = Box3DMode.convert(pred_bboxes, box_mode_3d,
                                                Box3DMode.DEPTH)
            elif box_mode_3d != Box3DMode.DEPTH:
                raise ValueError(
                    f'Unsupported box_mode_3d {box_mode_3d} for conversion!')

            pred_bboxes = pred_bboxes.tensor.cpu().numpy()
            show_result(
                points,
                None,
                pred_bboxes,
                out_dir,
                file_name,
                show=show,
                pred_labels=pred_labels)