mirror of https://github.com/hero-y/BHRL
362 lines
13 KiB
Python
362 lines
13 KiB
Python
import os.path as osp
|
|
import warnings
|
|
from collections import OrderedDict
|
|
|
|
import mmcv
|
|
import numpy as np
|
|
from mmcv.utils import print_log
|
|
from terminaltables import AsciiTable
|
|
from torch.utils.data import Dataset
|
|
|
|
from mmdet.core import eval_map, eval_recalls
|
|
from .builder import DATASETS
|
|
from .pipelines import Compose
|
|
|
|
|
|
@DATASETS.register_module()
|
|
class CustomDataset(Dataset):
|
|
"""Custom dataset for detection.
|
|
|
|
The annotation format is shown as follows. The `ann` field is optional for
|
|
testing.
|
|
|
|
.. code-block:: none
|
|
|
|
[
|
|
{
|
|
'filename': 'a.jpg',
|
|
'width': 1280,
|
|
'height': 720,
|
|
'ann': {
|
|
'bboxes': <np.ndarray> (n, 4) in (x1, y1, x2, y2) order.
|
|
'labels': <np.ndarray> (n, ),
|
|
'bboxes_ignore': <np.ndarray> (k, 4), (optional field)
|
|
'labels_ignore': <np.ndarray> (k, ), (optional field)
|
|
}
|
|
},
|
|
...
|
|
]
|
|
|
|
Args:
|
|
ann_file (str): Annotation file path.
|
|
pipeline (list[dict]): Processing pipeline.
|
|
classes (str | Sequence[str], optional): Specify classes to load.
|
|
If is None, ``cls.CLASSES`` will be used. Default: None.
|
|
data_root (str, optional): Data root for ``ann_file``,
|
|
``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified.
|
|
test_mode (bool, optional): If set True, annotation will not be loaded.
|
|
filter_empty_gt (bool, optional): If set true, images without bounding
|
|
boxes of the dataset's classes will be filtered out. This option
|
|
only works when `test_mode=False`, i.e., we never filter images
|
|
during tests.
|
|
"""
|
|
|
|
CLASSES = None
|
|
|
|
def __init__(self,
             ann_file,
             pipeline,
             classes=None,
             data_root=None,
             img_prefix='',
             seg_prefix=None,
             proposal_file=None,
             test_mode=False,
             filter_empty_gt=True):
    """Resolve paths, load annotations/proposals and build the pipeline.

    Args:
        ann_file (str): Annotation file path.
        pipeline (list[dict]): Processing pipeline; wrapped in ``Compose``.
        classes (str | Sequence[str], optional): Forwarded to
            ``get_classes`` to determine ``self.CLASSES``.
        data_root (str, optional): If given, it is prepended to every
            non-absolute path among ``ann_file``, ``img_prefix``,
            ``seg_prefix`` and ``proposal_file``.
        img_prefix (str, optional): Stored as-is; ``None`` is never joined
            with ``data_root``.
        seg_prefix (str, optional): Same handling as ``img_prefix``.
        proposal_file (str, optional): If not None, it is loaded with
            ``load_proposals``; otherwise ``self.proposals`` is None.
        test_mode (bool, optional): If True, image filtering and the
            group flag are skipped.
        filter_empty_gt (bool, optional): Stored on the instance and read
            by ``_filter_imgs``.
    """
    self.ann_file = ann_file
    self.data_root = data_root
    self.img_prefix = img_prefix
    self.seg_prefix = seg_prefix
    self.proposal_file = proposal_file
    self.test_mode = test_mode
    self.filter_empty_gt = filter_empty_gt
    self.CLASSES = self.get_classes(classes)

    # Anchor every relative path at data_root when one is configured.
    root = self.data_root
    if root is not None:
        if not osp.isabs(self.ann_file):
            self.ann_file = osp.join(root, self.ann_file)
        # The optional paths may be None, in which case they stay None.
        for attr in ('img_prefix', 'seg_prefix', 'proposal_file'):
            path = getattr(self, attr)
            if path is not None and not osp.isabs(path):
                setattr(self, attr, osp.join(root, path))

    # Annotations are always loaded; proposals only when a file is given.
    self.data_infos = self.load_annotations(self.ann_file)
    self.proposals = (
        self.load_proposals(self.proposal_file)
        if self.proposal_file is not None else None)

    if not test_mode:
        # Keep data_infos and proposals aligned by applying the same
        # index selection to both.
        keep = self._filter_imgs()
        self.data_infos = [self.data_infos[i] for i in keep]
        if self.proposals is not None:
            self.proposals = [self.proposals[i] for i in keep]
        # The group flag is only computed for training.
        self._set_group_flag()

    self.pipeline = Compose(pipeline)
|
|
|
|
def __len__(self):
    """Return the number of entries in ``self.data_infos``."""
    samples = self.data_infos
    return len(samples)
|
|
|
|
def load_annotations(self, ann_file):
    """Read ``ann_file`` with ``mmcv.load`` and return the loaded content.

    Args:
        ann_file (str): Path of the annotation file.

    Returns:
        The object produced by ``mmcv.load``; ``__init__`` stores it as
        ``self.data_infos``.
    """
    annotations = mmcv.load(ann_file)
    return annotations
|
|
|
|
def load_proposals(self, proposal_file):
    """Read ``proposal_file`` with ``mmcv.load`` and return the content.

    Args:
        proposal_file (str): Path of the proposal file.

    Returns:
        The object produced by ``mmcv.load``; ``__init__`` stores it as
        ``self.proposals``.
    """
    proposals = mmcv.load(proposal_file)
    return proposals
|
|
|
|
def get_ann_info(self, idx):
    """Return the ``'ann'`` entry of the ``idx``-th data info.

    Args:
        idx (int): Index of data.

    Returns:
        dict: Annotation info of specified index.
    """
    info = self.data_infos[idx]
    return info['ann']
|
|
|
|
def get_cat_ids(self, idx):
    """Get category ids by index.

    Args:
        idx (int): Index of data.

    Returns:
        list[int]: The ``'labels'`` of the image at ``idx`` converted to
        plain Python ints (duplicates are kept).
    """
    labels = self.data_infos[idx]['ann']['labels']
    # ``np.int`` was merely an alias of the builtin ``int``; it was
    # deprecated in NumPy 1.20 and removed in 1.24, where it raises
    # AttributeError. ``np.asarray`` additionally lets plain sequences
    # through without changing the result for ndarray input.
    return np.asarray(labels).astype(int).tolist()
|
|
|
|
def pre_pipeline(self, results):
    """Seed ``results`` in place with the prefixes and empty field lists.

    Args:
        results (dict): Dict that is about to be passed to the pipeline.
            It is modified in place; nothing is returned.
    """
    results.update(
        img_prefix=self.img_prefix,
        seg_prefix=self.seg_prefix,
        proposal_file=self.proposal_file,
        # Fresh lists on every call so samples never share them.
        bbox_fields=[],
        mask_fields=[],
        seg_fields=[],
    )
|
|
|
|
def _filter_imgs(self, min_size=32):
    """Return the indices of images whose shorter side is large enough.

    Args:
        min_size (int): Minimum allowed value of ``min(width, height)``.

    Returns:
        list[int]: Indices into ``self.data_infos`` that pass the check.
    """
    if self.filter_empty_gt:
        # Only the size check below is applied; empty-gt filtering is
        # not done here, hence the warning.
        warnings.warn(
            'CustomDataset does not support filtering empty gt images.')
    return [
        index for index, info in enumerate(self.data_infos)
        if min(info['width'], info['height']) >= min_size
    ]
|
|
|
|
def _set_group_flag(self):
    """Store a per-image aspect-ratio group in ``self.flag``.

    Images whose ``width / height`` is greater than 1 get flag 1, all
    others get flag 0. The result is a ``np.uint8`` array of length
    ``len(self)``.
    """
    count = len(self)
    is_wide = [
        self.data_infos[i]['width'] / self.data_infos[i]['height'] > 1
        for i in range(count)
    ]
    self.flag = np.array(is_wide, dtype=np.uint8)
|
|
|
|
def _rand_another(self, idx):
    """Draw a random index whose group flag equals that of ``idx``.

    The draw uses ``np.random.choice`` over all indices sharing the
    flag, so ``idx`` itself may be returned.
    """
    same_group = self.flag == self.flag[idx]
    candidates = np.where(same_group)[0]
    return np.random.choice(candidates)
|
|
|
|
def __getitem__(self, idx):
    """Get training/test data after pipeline.

    Args:
        idx (int): Index of data.

    Returns:
        dict: Result of ``prepare_test_img`` when ``test_mode`` is True,
        otherwise the first non-None result of ``prepare_train_img``.
    """
    if self.test_mode:
        return self.prepare_test_img(idx)

    data = self.prepare_train_img(idx)
    # A None result means the sample is unusable; retry with another
    # index from the same group until one succeeds.
    while data is None:
        idx = self._rand_another(idx)
        data = self.prepare_train_img(idx)
    return data
|
|
|
|
def prepare_train_img(self, idx):
    """Build the training sample for ``idx`` and run it through the pipeline.

    Args:
        idx (int): Index of data.

    Returns:
        Whatever ``self.pipeline`` returns for a dict holding
        ``img_info``, ``ann_info``, optionally ``proposals``, and the
        keys added by ``pre_pipeline``.
    """
    results = {
        'img_info': self.data_infos[idx],
        'ann_info': self.get_ann_info(idx),
    }
    proposals = self.proposals
    if proposals is not None:
        results['proposals'] = proposals[idx]
    self.pre_pipeline(results)
    return self.pipeline(results)
|
|
|
|
def prepare_test_img(self, idx):
    """Build the test sample for ``idx`` and run it through the pipeline.

    Args:
        idx (int): Index of data.

    Returns:
        Whatever ``self.pipeline`` returns for a dict holding
        ``img_info``, optionally ``proposals``, and the keys added by
        ``pre_pipeline``. No annotation is attached.
    """
    results = {'img_info': self.data_infos[idx]}
    proposals = self.proposals
    if proposals is not None:
        results['proposals'] = proposals[idx]
    self.pre_pipeline(results)
    return self.pipeline(results)
|
|
|
|
@classmethod
def get_classes(cls, classes=None):
    """Resolve the class names to use for this dataset.

    Args:
        classes (Sequence[str] | str | None): ``None`` selects
            ``cls.CLASSES``. A string is treated as a file path and read
            with ``mmcv.list_from_file``. A tuple or list is returned
            unchanged and overrides ``cls.CLASSES``.

    Returns:
        tuple[str] or list[str]: Names of categories of the dataset.

    Raises:
        ValueError: If ``classes`` is none of the supported types.
    """
    if classes is None:
        return cls.CLASSES
    if isinstance(classes, str):
        # A string is interpreted as a file path, not as a class name.
        return mmcv.list_from_file(classes)
    if isinstance(classes, (tuple, list)):
        return classes
    raise ValueError(f'Unsupported type {type(classes)} of classes.')
|
|
|
|
def format_results(self, results, **kwargs):
    """Placeholder for formatting results into dataset-specific output.

    This implementation has no body besides this docstring, so it
    accepts ``results`` and any keyword arguments and returns None.
    """
|
|
|
|
def evaluate(self,
             results,
             metric='mAP',
             logger=None,
             proposal_nums=(100, 300, 1000),
             iou_thr=0.5,
             scale_ranges=None):
    """Evaluate the dataset with either mAP or recall.

    Args:
        results (list): Testing results of the dataset.
        metric (str | list[str]): Metric to be evaluated, ``'mAP'`` or
            ``'recall'``. A non-string value must contain exactly one
            entry.
        logger (logging.Logger | None | str): Logger used for printing
            related information during evaluation. Default: None.
        proposal_nums (Sequence[int]): Proposal number used for evaluating
            recalls, such as recall@100, recall@1000.
            Default: (100, 300, 1000).
        iou_thr (float | list[float]): IoU threshold. For ``'mAP'`` a
            non-float value must be a list. Default: 0.5.
        scale_ranges (list[tuple] | None): Scale ranges for evaluating mAP.
            Default: None.

    Returns:
        OrderedDict: For ``'mAP'``, one ``AP<thr*100>`` entry per
        threshold plus their mean under ``'mAP'``. For ``'recall'``, one
        ``recall@<num>@<iou>`` entry per combination, plus ``AR@<num>``
        entries when more than one IoU column is present.

    Raises:
        KeyError: If ``metric`` is not supported.
    """
    if not isinstance(metric, str):
        assert len(metric) == 1
        metric = metric[0]
    if metric not in ('mAP', 'recall'):
        raise KeyError(f'metric {metric} is not supported')

    annotations = list(map(self.get_ann_info, range(len(self))))
    eval_results = OrderedDict()
    # A single float threshold is treated as a one-element list.
    iou_thrs = iou_thr if not isinstance(iou_thr, float) else [iou_thr]

    if metric == 'mAP':
        assert isinstance(iou_thrs, list)
        rule = '-' * 15
        mean_aps = []
        for thr in iou_thrs:
            print_log(f'\n{rule}iou_thr: {thr}{rule}')
            mean_ap, _ = eval_map(
                results,
                annotations,
                scale_ranges=scale_ranges,
                iou_thr=thr,
                dataset=self.CLASSES,
                logger=logger)
            mean_aps.append(mean_ap)
            ap_key = f'AP{int(thr * 100):02d}'
            eval_results[ap_key] = round(mean_ap, 3)
        eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
    elif metric == 'recall':
        gt_bboxes = [ann['bboxes'] for ann in annotations]
        # eval_recalls receives the threshold exactly as the caller
        # passed it; the result is indexed [proposal_num, iou] below.
        recalls = eval_recalls(
            gt_bboxes, results, proposal_nums, iou_thr, logger=logger)
        for row, num in enumerate(proposal_nums):
            for col, iou in enumerate(iou_thrs):
                eval_results[f'recall@{num}@{iou}'] = recalls[row, col]
        if recalls.shape[1] > 1:
            # Average recall over the IoU axis.
            average_recall = recalls.mean(axis=1)
            for row, num in enumerate(proposal_nums):
                eval_results[f'AR@{num}'] = average_recall[row]
    return eval_results
|
|
|
|
def __repr__(self):
    """Summarise the dataset as a header line plus a per-class count table.

    Returns:
        str: The header alone (with a note) when ``self.CLASSES`` is
        None; otherwise the header followed by an ASCII table listing
        the instance count of every class and a final background slot.
    """
    dataset_type = 'Test' if self.test_mode else 'Train'
    result = (f'\n{self.__class__.__name__} {dataset_type} dataset '
              f'with number of images {len(self)}, '
              f'and instance counts: \n')
    if self.CLASSES is None:
        result += 'Category names are not provided. \n'
        return result

    num_classes = len(self.CLASSES)
    # One slot per class plus a trailing slot used as "background".
    instance_count = np.zeros(num_classes + 1).astype(int)
    # count the instance number in each image
    for idx in range(len(self)):
        label = self.get_ann_info(idx)['labels']
        unique, counts = np.unique(label, return_counts=True)
        if len(unique) > 0:
            # add the occurrence number to each class
            instance_count[unique] += counts
        else:
            # An image without any label is counted in the last slot.
            instance_count[-1] += 1

    # create a table with category count, five (category, count) pairs
    # per row
    pairs_per_row = 5
    cells_per_row = 2 * pairs_per_row
    table_data = [['category', 'count'] * pairs_per_row]
    row_data = []
    for class_id, count in enumerate(instance_count):
        if class_id < num_classes:
            row_data += [f'{class_id} [{self.CLASSES[class_id]}]',
                         f'{count}']
        else:
            # add the background number
            row_data += ['-1 background', f'{count}']
        if len(row_data) == cells_per_row:
            table_data.append(row_data)
            row_data = []
    if row_data:
        # Flush the last, partially filled row. Without this every entry
        # after the last complete row (including the background count)
        # was silently left out of the table. Pad it so all rows have
        # the same number of cells.
        row_data += [''] * (cells_per_row - len(row_data))
        table_data.append(row_data)

    table = AsciiTable(table_data)
    result += table.table
    return result
|