Cathy0908 54e9571423
add BEVFormer (#203)
* add BEVFormer and benchmark
2022-10-24 17:20:12 +08:00

266 lines
9.0 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Alibaba, Inc. and its affiliates.
from os import path as osp
import mmcv
import numpy as np
from easycv.core.bbox import get_box_type
from easycv.datasets.shared.pipelines import Compose
class Det3dSourceBase(object):
    """Base 3D data source.

    The annotation file is expected to contain a list of per-sample dicts
    laid out as follows::

        [
            {'sample_idx':
             'lidar_points': {'lidar_path': velodyne_path,
                              ....
                              },
             'annos': {'box_type_3d': (str) 'LiDAR/Camera/Depth'
                       'gt_bboxes_3d': <np.ndarray> (n, 7)
                       'gt_names': [list]
                       ....
                       }
             'calib': { .....}
             'images': { .....}
             }
        ]

    Note:
        ``CLASSES`` is not defined on this base class. ``get_classes`` reads
        ``cls.CLASSES`` when ``classes`` is None, so a subclass has to
        provide it in that case.

    Args:
        data_root (str): Path of data root.
        ann_file (str): Path of annotation file.
        pipeline (list[dict], optional): Pipeline used for data processing.
            Defaults to None. When None, no ``pipeline`` attribute is set,
            so ``prepare_train_data`` / ``prepare_test_data`` cannot be used.
        classes (list[str], optional): Classes of the dataset.
            Defaults to None.
        modality (dict, optional): Modality to specify the sensor data used
            as input. Defaults to None.
        box_type_3d (str, optional): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            in its original format and then convert it to `box_type_3d`.
            Defaults to 'LiDAR'. Available options include

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Depth': Box in depth coordinates, usually for indoor dataset.
            - 'Camera': Box in camera coordinates.
        filter_empty_gt (bool, optional): Whether to filter empty GT.
            Defaults to True.
        test_mode (bool, optional): Whether the dataset is in test mode.
            Defaults to False.
    """

    def __init__(self,
                 data_root,
                 ann_file,
                 pipeline=None,
                 classes=None,
                 modality=None,
                 box_type_3d='LiDAR',
                 filter_empty_gt=True,
                 test_mode=False):
        super().__init__()
        self.data_root = data_root
        self.ann_file = ann_file
        self.test_mode = test_mode
        self.modality = modality
        self.filter_empty_gt = filter_empty_gt
        self.box_type_3d, self.box_mode_3d = get_box_type(box_type_3d)

        self.CLASSES = self.get_classes(classes)
        # Class name -> label index, in the order given by CLASSES.
        self.cat2id = {name: i for i, name in enumerate(self.CLASSES)}
        self.data_infos = self.load_annotations(self.ann_file)

        # The pipeline attribute is only created when a pipeline config
        # is supplied.
        if pipeline is not None:
            self.pipeline = Compose(pipeline)

    def load_annotations(self, ann_file):
        """Load annotations from ann_file.

        Args:
            ann_file (str): Path of the annotation file.

        Returns:
            list[dict]: List of annotations.
        """
        # Loading data from a file-like object needs the file format, so it
        # is always passed explicitly.
        return mmcv.load(ann_file, file_format='pkl')

    def get_data_info(self, index):
        """Get data info according to the given index.

        Args:
            index (int): Index of the sample data to get.

        Returns:
            dict | None: Data information that will be passed to the data
            preprocessing pipelines, or None when not in test mode,
            ``filter_empty_gt`` is set and no label differs from -1.
            It includes the following keys:

            - sample_idx (str): Sample index.
            - pts_filename (str): Filename of point clouds.
            - file_name (str): Filename of point clouds.
            - ann_info (dict): Annotation info (only when not in test mode).
        """
        info = self.data_infos[index]
        sample_idx = info['sample_idx']
        pts_filename = osp.join(self.data_root,
                                info['lidar_points']['lidar_path'])

        input_dict = dict(
            pts_filename=pts_filename,
            sample_idx=sample_idx,
            file_name=pts_filename)

        if not self.test_mode:
            annos = self.get_ann_info(index)
            input_dict['ann_info'] = annos
            # -1 is the label given to names that are not in CLASSES, so a
            # sample whose labels are all -1 (or empty) has no usable GT.
            # `not` is used instead of bitwise `~` so the check is also
            # correct if `.any()` returns a plain Python bool.
            if self.filter_empty_gt and not (annos['gt_labels_3d']
                                             != -1).any():
                return None
        return input_dict

    def get_ann_info(self, index):
        """Get annotation info according to the given index.

        Args:
            index (int): Index of the annotation data to get.

        Returns:
            dict: Annotation information consists of the following keys:

            - gt_bboxes_3d: 3D ground truth bboxes, wrapped in the box
              class named by the info file and converted to
              ``self.box_mode_3d``.
            - gt_labels_3d (np.ndarray): Labels of ground truths, -1 for
              names that are not in ``CLASSES``.
            - gt_names (list[str]): Class names of ground truths.
        """
        info = self.data_infos[index]
        gt_bboxes_3d = info['annos']['gt_bboxes_3d']
        gt_names_3d = info['annos']['gt_names']

        gt_labels_3d = np.array([
            self.CLASSES.index(cat) if cat in self.CLASSES else -1
            for cat in gt_names_3d
        ])

        # Obtain original box 3d type in info file
        ori_box_type_3d = info['annos']['box_type_3d']
        ori_box_type_3d, _ = get_box_type(ori_box_type_3d)

        # Turn original box type to target box type.
        # NOTE(review): origin=(0.5, 0.5, 0.5) presumably declares the stored
        # boxes as centre-referenced -- confirm against the box class.
        gt_bboxes_3d = ori_box_type_3d(
            gt_bboxes_3d,
            box_dim=gt_bboxes_3d.shape[-1],
            origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)

        anns_results = dict(
            gt_bboxes_3d=gt_bboxes_3d,
            gt_labels_3d=gt_labels_3d,
            gt_names=gt_names_3d)
        return anns_results

    def pre_pipeline(self, results):
        """Initialization before data preparation.

        The dict is modified in place; nothing is returned.

        Args:
            results (dict): Dict before data preprocessing. The following
                keys are (re)set:

                - img_fields (list): Image fields.
                - bbox3d_fields (list): 3D bounding boxes fields.
                - pts_mask_fields (list): Mask fields of points.
                - pts_seg_fields (list): Mask fields of point segments.
                - bbox_fields (list): Fields of bounding boxes.
                - mask_fields (list): Fields of masks.
                - seg_fields (list): Segment fields.
                - box_type_3d: ``self.box_type_3d``.
                - box_mode_3d: ``self.box_mode_3d``.
        """
        results['img_fields'] = []
        results['bbox3d_fields'] = []
        results['pts_mask_fields'] = []
        results['pts_seg_fields'] = []
        results['bbox_fields'] = []
        results['mask_fields'] = []
        results['seg_fields'] = []
        results['box_type_3d'] = self.box_type_3d
        results['box_mode_3d'] = self.box_mode_3d

    def prepare_train_data(self, index):
        """Training data preparation.

        Args:
            index (int): Index for accessing the target data.

        Returns:
            dict | None: Training data dict of the corresponding index, or
            None when the sample is filtered out (before or after the
            pipeline).
        """
        input_dict = self.get_data_info(index)
        if input_dict is None:
            return None
        self.pre_pipeline(input_dict)
        example = self.pipeline(input_dict)
        # The pipeline itself may drop the sample (None) or remove all valid
        # boxes; both count as empty GT.
        if self.filter_empty_gt and (
                example is None
                or not (example['gt_labels_3d'] != -1).any()):
            return None
        return example

    def prepare_test_data(self, index):
        """Prepare data for testing.

        Args:
            index (int): Index for accessing the target data.

        Returns:
            dict: Testing data dict of the corresponding index.
        """
        input_dict = self.get_data_info(index)
        self.pre_pipeline(input_dict)
        example = self.pipeline(input_dict)
        return example

    @classmethod
    def get_classes(cls, classes=None):
        """Get class names of current dataset.

        Args:
            classes (Sequence[str] | str): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes is
                a tuple or list, override the CLASSES defined by the dataset.

        Return:
            list[str]: A list of class names (a tuple or list argument is
            returned unchanged).

        Raises:
            ValueError: If ``classes`` is neither None, a str, a tuple nor
                a list.
        """
        if classes is None:
            return cls.CLASSES

        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        return class_names

    def __len__(self):
        """Return the length of data infos.

        Returns:
            int: Length of data infos.
        """
        return len(self.data_infos)

    def __getitem__(self, idx):
        """Get item from infos according to the given index.

        Dispatches to ``prepare_test_data`` in test mode and to
        ``prepare_train_data`` otherwise.

        Returns:
            dict: Data dictionary of the corresponding index.
        """
        if self.test_mode:
            return self.prepare_test_data(idx)
        else:
            return self.prepare_train_data(idx)