mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
266 lines
9.0 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
from os import path as osp
|
|
|
|
import mmcv
|
|
import numpy as np
|
|
|
|
from easycv.core.bbox import get_box_type
|
|
from easycv.datasets.shared.pipelines import Compose
|
|
|
|
|
|
class Det3dSourceBase(object):
    """Base 3D data source.

    The annotation file is expected to hold a list of per-sample dicts
    shaped like::

        [
            {'sample_idx': ...,
             'lidar_points': {'lidar_path': velodyne_path,
                              ....
                             },
             'annos': {'box_type_3d': (str) 'LiDAR/Camera/Depth',
                       'gt_bboxes_3d': <np.ndarray> (n, 7),
                       'gt_names': [list],
                       ....
                      },
             'calib': { .....},
             'images': { .....},
            }
        ]

    Args:
        data_root (str): Path of data root.
        ann_file (str): Path of annotation file.
        pipeline (list[dict], optional): Pipeline used for data processing.
            Defaults to None.
        classes (list[str], optional): Classes of the dataset.
            Defaults to None.
        modality (dict, optional): Modality to specify the sensor data used
            as input. Defaults to None.
        box_type_3d (str, optional): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            to its original format then converted them to `box_type_3d`.
            Defaults to 'LiDAR'. Available options includes

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Depth': Box in depth coordinates, usually for indoor dataset.
            - 'Camera': Box in camera coordinates.
        filter_empty_gt (bool, optional): Whether to filter empty GT.
            Defaults to True.
        test_mode (bool, optional): Whether the dataset is in test mode.
            Defaults to False.
    """
|
|
|
|
    def __init__(self,
                 data_root,
                 ann_file,
                 pipeline=None,
                 classes=None,
                 modality=None,
                 box_type_3d='LiDAR',
                 filter_empty_gt=True,
                 test_mode=False):
        """Initialize the data source.

        See the class docstring for the meaning of each argument.
        Loads the annotation file eagerly via ``load_annotations``.
        """
        super().__init__()
        self.data_root = data_root
        self.ann_file = ann_file
        self.test_mode = test_mode
        self.modality = modality
        self.filter_empty_gt = filter_empty_gt
        # Resolve the target box type/mode pair from the string name.
        self.box_type_3d, self.box_mode_3d = get_box_type(box_type_3d)

        self.CLASSES = self.get_classes(classes)
        # Map each class name to its integer id (position in CLASSES).
        self.cat2id = {name: i for i, name in enumerate(self.CLASSES)}
        self.data_infos = self.load_annotations(self.ann_file)

        # process pipeline
        # NOTE(review): when pipeline is None, self.pipeline is never set,
        # so prepare_train_data/prepare_test_data would raise AttributeError
        # — confirm callers always pass a pipeline.
        if pipeline is not None:
            self.pipeline = Compose(pipeline)
|
|
|
|
def load_annotations(self, ann_file):
|
|
"""Load annotations from ann_file.
|
|
|
|
Args:
|
|
ann_file (str): Path of the annotation file.
|
|
|
|
Returns:
|
|
list[dict]: List of annotations.
|
|
"""
|
|
# loading data from a file-like object needs file format
|
|
return mmcv.load(ann_file, file_format='pkl')
|
|
|
|
def get_data_info(self, index):
|
|
"""Get data info according to the given index.
|
|
|
|
Args:
|
|
index (int): Index of the sample data to get.
|
|
|
|
Returns:
|
|
dict: Data information that will be passed to the data
|
|
preprocessing pipelines. It includes the following keys:
|
|
|
|
- sample_idx (str): Sample index.
|
|
- pts_filename (str): Filename of point clouds.
|
|
- file_name (str): Filename of point clouds.
|
|
- ann_info (dict): Annotation info.
|
|
"""
|
|
info = self.data_infos[index]
|
|
sample_idx = info['sample_idx']
|
|
pts_filename = osp.join(self.data_root,
|
|
info['lidar_points']['lidar_path'])
|
|
|
|
input_dict = dict(
|
|
pts_filename=pts_filename,
|
|
sample_idx=sample_idx,
|
|
file_name=pts_filename)
|
|
|
|
if not self.test_mode:
|
|
annos = self.get_ann_info(index)
|
|
input_dict['ann_info'] = annos
|
|
if self.filter_empty_gt and ~(annos['gt_labels_3d'] != -1).any():
|
|
return None
|
|
return input_dict
|
|
|
|
def get_ann_info(self, index):
|
|
"""Get annotation info according to the given index.
|
|
|
|
Args:
|
|
index (int): Index of the annotation data to get.
|
|
|
|
Returns:
|
|
dict: Annotation information consists of the following keys:
|
|
|
|
- gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`):
|
|
3D ground truth bboxes
|
|
- gt_labels_3d (np.ndarray): Labels of ground truths.
|
|
- gt_names (list[str]): Class names of ground truths.
|
|
"""
|
|
info = self.data_infos[index]
|
|
gt_bboxes_3d = info['annos']['gt_bboxes_3d']
|
|
gt_names_3d = info['annos']['gt_names']
|
|
gt_labels_3d = []
|
|
for cat in gt_names_3d:
|
|
if cat in self.CLASSES:
|
|
gt_labels_3d.append(self.CLASSES.index(cat))
|
|
else:
|
|
gt_labels_3d.append(-1)
|
|
gt_labels_3d = np.array(gt_labels_3d)
|
|
|
|
# Obtain original box 3d type in info file
|
|
ori_box_type_3d = info['annos']['box_type_3d']
|
|
ori_box_type_3d, _ = get_box_type(ori_box_type_3d)
|
|
|
|
# turn original box type to target box type
|
|
gt_bboxes_3d = ori_box_type_3d(
|
|
gt_bboxes_3d,
|
|
box_dim=gt_bboxes_3d.shape[-1],
|
|
origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)
|
|
|
|
anns_results = dict(
|
|
gt_bboxes_3d=gt_bboxes_3d,
|
|
gt_labels_3d=gt_labels_3d,
|
|
gt_names=gt_names_3d)
|
|
return anns_results
|
|
|
|
def pre_pipeline(self, results):
|
|
"""Initialization before data preparation.
|
|
|
|
Args:
|
|
results (dict): Dict before data preprocessing.
|
|
|
|
- img_fields (list): Image fields.
|
|
- bbox3d_fields (list): 3D bounding boxes fields.
|
|
- pts_mask_fields (list): Mask fields of points.
|
|
- pts_seg_fields (list): Mask fields of point segments.
|
|
- bbox_fields (list): Fields of bounding boxes.
|
|
- mask_fields (list): Fields of masks.
|
|
- seg_fields (list): Segment fields.
|
|
- box_type_3d (str): 3D box type.
|
|
- box_mode_3d (str): 3D box mode.
|
|
"""
|
|
results['img_fields'] = []
|
|
results['bbox3d_fields'] = []
|
|
results['pts_mask_fields'] = []
|
|
results['pts_seg_fields'] = []
|
|
results['bbox_fields'] = []
|
|
results['mask_fields'] = []
|
|
results['seg_fields'] = []
|
|
results['box_type_3d'] = self.box_type_3d
|
|
results['box_mode_3d'] = self.box_mode_3d
|
|
|
|
def prepare_train_data(self, index):
|
|
"""Training data preparation.
|
|
|
|
Args:
|
|
index (int): Index for accessing the target data.
|
|
|
|
Returns:
|
|
dict: Training data dict of the corresponding index.
|
|
"""
|
|
input_dict = self.get_data_info(index)
|
|
if input_dict is None:
|
|
return None
|
|
self.pre_pipeline(input_dict)
|
|
example = self.pipeline(input_dict)
|
|
if self.filter_empty_gt and (example is None or
|
|
~(example['gt_labels_3d'] != -1).any()):
|
|
return None
|
|
return example
|
|
|
|
def prepare_test_data(self, index):
|
|
"""Prepare data for testing.
|
|
|
|
Args:
|
|
index (int): Index for accessing the target data.
|
|
|
|
Returns:
|
|
dict: Testing data dict of the corresponding index.
|
|
"""
|
|
input_dict = self.get_data_info(index)
|
|
self.pre_pipeline(input_dict)
|
|
example = self.pipeline(input_dict)
|
|
return example
|
|
|
|
def get_classes(cls, classes=None):
|
|
"""Get class names of current dataset.
|
|
|
|
Args:
|
|
classes (Sequence[str] | str): If classes is None, use
|
|
default CLASSES defined by builtin dataset. If classes is a
|
|
string, take it as a file name. The file contains the name of
|
|
classes where each line contains one class name. If classes is
|
|
a tuple or list, override the CLASSES defined by the dataset.
|
|
|
|
Return:
|
|
list[str]: A list of class names.
|
|
"""
|
|
if classes is None:
|
|
return cls.CLASSES
|
|
|
|
if isinstance(classes, str):
|
|
# take it as a file path
|
|
class_names = mmcv.list_from_file(classes)
|
|
elif isinstance(classes, (tuple, list)):
|
|
class_names = classes
|
|
else:
|
|
raise ValueError(f'Unsupported type {type(classes)} of classes.')
|
|
|
|
return class_names
|
|
|
|
def __len__(self):
|
|
"""Return the length of data infos.
|
|
|
|
Returns:
|
|
int: Length of data infos.
|
|
"""
|
|
return len(self.data_infos)
|
|
|
|
def __getitem__(self, idx):
|
|
"""Get item from infos according to the given index.
|
|
|
|
Returns:
|
|
dict: Data dictionary of the corresponding index.
|
|
"""
|
|
if self.test_mode:
|
|
return self.prepare_test_data(idx)
|
|
else:
|
|
return self.prepare_train_data(idx)
|