# mmfewshot/mmfewshot/detection/datasets/base.py
# (548 lines, 24 KiB, Python)
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import json
import os.path as osp
import warnings
from typing import Dict, List, Optional, Sequence, Union
import numpy as np
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset
from mmdet.datasets.pipelines import Compose
from terminaltables import AsciiTable
from mmfewshot.utils import get_root_logger
from .utils import NumpyEncoder
@DATASETS.register_module()
class BaseFewShotDataset(CustomDataset):
    """Base dataset for few shot detection.

    The main differences with normal detection dataset fall in two aspects.

        - It allows to specify single (used in normal dataset) or multiple
          (used in query-support dataset) pipelines for data processing.
        - It supports to control the maximum number of instances of each class
          when loading the annotation file.

    The annotation format is shown as follows. The `ann` field
    is optional for testing.

    .. code-block:: none

        [
            {
                'id': '0000001'
                'filename': 'a.jpg',
                'width': 1280,
                'height': 720,
                'ann': {
                    'bboxes': <np.ndarray> (n, 4) in (x1, y1, x2, y2) order.
                    'labels': <np.ndarray> (n, ),
                    'bboxes_ignore': <np.ndarray> (k, 4), (optional field)
                    'labels_ignore': <np.ndarray> (k, 4) (optional field)
                }
            },
            ...
        ]

    Args:
        ann_cfg (list[dict]): Annotation config support two type of config.

            - loading annotation from common ann_file of dataset
              with or without specific classes.
              example: dict(type='ann_file', ann_file='path/to/ann_file',
              ann_classes=['dog', 'cat'])
            - loading annotation from a json file saved by dataset.
              example: dict(type='saved_dataset', ann_file='path/to/ann_file')
        classes (str | Sequence[str] | None): Classes for model training and
            provide fixed label for each class.
        pipeline (list[dict] | None): Config to specify processing pipeline.
            Used in normal dataset. Default: None.
        multi_pipelines (dict[list[dict]]): Config to specify
            data pipelines for corresponding data flow. For example, query
            and support data can be processed with two different pipelines,
            the dict should contain two keys like:

            - query (list[dict]): Config for query-data process pipeline.
            - support (list[dict]): Config for support-data process pipeline.
        data_root (str | None): Data root for ``ann_cfg``, ``img_prefix``,
            ``seg_prefix``, ``proposal_file`` if specified. Default: None.
        test_mode (bool): If set True, annotation will not be loaded.
            Default: False.
        filter_empty_gt (bool): If set true, images without bounding
            boxes of the dataset's classes will be filtered out. This option
            only works when `test_mode=False`, i.e., we never filter images
            during tests. Default: True.
        min_bbox_size (int | float | None): The minimum size of bounding
            boxes in the images. If the size of a bounding box is less than
            ``min_bbox_size``, it would be added to ignored field.
            Default: None.
        ann_shot_filter (dict | None): Used to specify the class and the
            corresponding maximum number of instances when loading
            the annotation file. For example: {'dog': 10, 'person': 5}.
            If set it as None, all annotation from ann file would be loaded.
            Default: None.
        instance_wise (bool): If set true, `self.data_infos`
            would change to instance-wise, which means if the annotation
            of single image has more than one instance, the annotation would
            be split to num_instances items. Often used in support datasets,
            Default: False.
        dataset_name (str | None): Name of dataset to display. For example:
            'train_dataset' or 'query_dataset'. Default: None.
    """

    CLASSES = None

    def __init__(self,
                 ann_cfg: List[Dict],
                 classes: Union[str, Sequence[str], None],
                 pipeline: Optional[List[Dict]] = None,
                 multi_pipelines: Optional[Dict[str, List[Dict]]] = None,
                 data_root: Optional[str] = None,
                 img_prefix: str = '',
                 seg_prefix: Optional[str] = None,
                 proposal_file: Optional[str] = None,
                 test_mode: bool = False,
                 filter_empty_gt: bool = True,
                 min_bbox_size: Optional[Union[int, float]] = None,
                 ann_shot_filter: Optional[Dict] = None,
                 instance_wise: bool = False,
                 dataset_name: Optional[str] = None) -> None:
        self.data_root = data_root
        self.img_prefix = img_prefix
        self.seg_prefix = seg_prefix
        self.proposal_file = proposal_file
        self.test_mode = test_mode
        self.filter_empty_gt = filter_empty_gt
        if classes is not None:
            self.CLASSES = self.get_classes(classes)
        self.instance_wise = instance_wise
        # set dataset name for readable assertion/log messages
        if dataset_name is None:
            self.dataset_name = 'Test dataset' \
                if test_mode else 'Train dataset'
        else:
            self.dataset_name = dataset_name
        # join paths if data_root is specified
        if self.data_root is not None:
            if not (self.img_prefix is None or osp.isabs(self.img_prefix)):
                self.img_prefix = osp.join(self.data_root, self.img_prefix)
            if not (self.proposal_file is None
                    or osp.isabs(self.proposal_file)):
                self.proposal_file = osp.join(self.data_root,
                                              self.proposal_file)
        # keep a pristine copy of the config; `ann_cfg_parser` may join
        # `data_root` into the `ann_file` paths of the argument in place
        self.ann_cfg = copy.deepcopy(ann_cfg)
        self.data_infos = self.ann_cfg_parser(ann_cfg)
        assert self.data_infos is not None, \
            f'{self.dataset_name} : none annotation loaded.'
        # load proposal file
        if self.proposal_file is not None:
            self.proposals = self.load_proposals(self.proposal_file)
        else:
            self.proposals = None
        # filter images too small and containing no annotations
        if not test_mode:
            # filter bbox smaller than the min_bbox_size
            if min_bbox_size:
                self.data_infos = self._filter_bboxs(min_bbox_size)
            # filter images
            valid_inds = self._filter_imgs()
            self.data_infos = [self.data_infos[i] for i in valid_inds]
            if self.proposals is not None:
                self.proposals = [self.proposals[i] for i in valid_inds]
            # filter annotations by ann_shot_filter
            if ann_shot_filter is not None:
                if isinstance(ann_shot_filter, dict):
                    for class_name in list(ann_shot_filter.keys()):
                        assert class_name in self.CLASSES, \
                            f'{self.dataset_name} : class ' \
                            f'{class_name} in ann_shot_filter not in CLASSES.'
                else:
                    raise TypeError('ann_shot_filter only support dict')
                self.ann_shot_filter = ann_shot_filter
                self.data_infos = self._filter_annotations(
                    self.data_infos, self.ann_shot_filter)
            # instance_wise will make each data info only contain one
            # annotation otherwise all annotation from same image will
            # be checked and merged.
            if self.instance_wise:
                instance_wise_data_infos = []
                for data_info in self.data_infos:
                    num_instance = data_info['ann']['labels'].size
                    # split annotations
                    if num_instance > 1:
                        for i in range(data_info['ann']['labels'].size):
                            tmp_data_info = copy.deepcopy(data_info)
                            tmp_data_info['ann']['labels'] = np.expand_dims(
                                data_info['ann']['labels'][i], axis=0)
                            tmp_data_info['ann']['bboxes'] = np.expand_dims(
                                data_info['ann']['bboxes'][i, :], axis=0)
                            instance_wise_data_infos.append(tmp_data_info)
                    else:
                        instance_wise_data_infos.append(data_info)
                self.data_infos = instance_wise_data_infos
            # merge different annotations with the same image
            else:
                merge_data_dict = {}
                for i, data_info in enumerate(self.data_infos):
                    # merge data_info with the same image id
                    if merge_data_dict.get(data_info['id'], None) is None:
                        merge_data_dict[data_info['id']] = data_info
                    else:
                        ann_a = merge_data_dict[data_info['id']]['ann']
                        ann_b = data_info['ann']
                        merge_dat_info = {
                            'bboxes':
                            np.concatenate((ann_a['bboxes'], ann_b['bboxes'])),
                            'labels':
                            np.concatenate((ann_a['labels'], ann_b['labels'])),
                        }
                        # merge `bboxes_ignore`
                        if ann_a.get('bboxes_ignore', None) is not None:
                            # use `np.array_equal` instead of
                            # `(a == b).all()`: comparing arrays of
                            # different shapes with `==` does not produce
                            # an elementwise result, so `.all()` would
                            # fail; `array_equal` returns False instead.
                            if not np.array_equal(ann_a['bboxes_ignore'],
                                                  ann_b['bboxes_ignore']):
                                merge_dat_info['bboxes_ignore'] = \
                                    np.concatenate((ann_a['bboxes_ignore'],
                                                    ann_b['bboxes_ignore']))
                                merge_dat_info['labels_ignore'] = \
                                    np.concatenate((ann_a['labels_ignore'],
                                                    ann_b['labels_ignore']))
                        merge_data_dict[
                            data_info['id']]['ann'] = merge_dat_info
                self.data_infos = [
                    merge_data_dict[key] for key in merge_data_dict.keys()
                ]
        # set group flag for the sampler
        self._set_group_flag()
        assert pipeline is None or multi_pipelines is None, \
            f'{self.dataset_name} : can not assign pipeline ' \
            f'or multi_pipelines simultaneously'
        # processing pipeline if there are two pipeline the
        # pipeline will be determined by key name of query or support
        if multi_pipelines is not None:
            assert isinstance(multi_pipelines, dict), \
                f'{self.dataset_name} : multi_pipelines is type of dict'
            self.multi_pipelines = {}
            for key in multi_pipelines.keys():
                self.multi_pipelines[key] = Compose(multi_pipelines[key])
        elif pipeline is not None:
            assert isinstance(pipeline, list), \
                f'{self.dataset_name} : pipeline is type of list'
            self.pipeline = Compose(pipeline)
        else:
            raise ValueError('missing pipeline or multi_pipelines')
        # show dataset annotation usage
        logger = get_root_logger()
        logger.info(repr(self))

    def ann_cfg_parser(self, ann_cfg: List[Dict]) -> List[Dict]:
        """Parse annotation config to annotation information.

        Args:
            ann_cfg (list[dict]): Annotation config support two type of
                config.

                - 'ann_file': loading annotation from common ann_file of
                  dataset. example: dict(type='ann_file',
                  ann_file='path/to/ann_file', ann_classes=['dog', 'cat'])
                - 'saved_dataset': loading annotation from saved dataset.
                  example: dict(type='saved_dataset',
                  ann_file='path/to/ann_file')

        Returns:
            list[dict]: Annotation information.
        """
        # validate the config structure before touching its entries,
        # so a malformed `ann_cfg` fails with a clear message instead
        # of an indexing error
        assert isinstance(ann_cfg, list), \
            f'{self.dataset_name} : ann_cfg should be type of list.'
        for ann_cfg_ in ann_cfg:
            assert isinstance(ann_cfg_, dict), \
                f'{self.dataset_name} : ann_cfg should be list of dict.'
            assert ann_cfg_['type'] in ['ann_file', 'saved_dataset'], \
                f'{self.dataset_name} : ann_cfg only support type of ' \
                f'ann_file and saved_dataset'
        # join paths if data_root is specified (mutates `ann_cfg` in place)
        if self.data_root is not None:
            for i in range(len(ann_cfg)):
                if not osp.isabs(ann_cfg[i]['ann_file']):
                    ann_cfg[i]['ann_file'] = \
                        osp.join(self.data_root, ann_cfg[i]['ann_file'])
        return self.load_annotations(ann_cfg)

    def get_ann_info(self, idx: int) -> Dict:
        """Get annotation by index.

        When override this function please make sure same annotations are
        used during the whole training.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """
        # deep copy so pipelines cannot mutate the cached annotations
        return copy.deepcopy(self.data_infos[idx]['ann'])

    def prepare_train_img(self,
                          idx: int,
                          pipeline_key: Optional[str] = None,
                          gt_idx: Optional[List[int]] = None) -> Dict:
        """Get training data and annotations after pipeline.

        Args:
            idx (int): Index of data.
            pipeline_key (str): Name of pipeline
            gt_idx (list[int]): Index of used annotation.

        Returns:
            dict: Training data and annotation after pipeline with new keys \
                introduced by pipeline.
        """
        img_info = self.data_infos[idx]
        ann_info = self.get_ann_info(idx)
        # select annotation in `gt_idx`
        if gt_idx is not None:
            selected_ann_info = {
                'bboxes': ann_info['bboxes'][gt_idx],
                'labels': ann_info['labels'][gt_idx]
            }
            # keep pace with new annotations
            new_img_info = copy.deepcopy(img_info)
            new_img_info['ann'] = selected_ann_info
            results = dict(img_info=new_img_info, ann_info=selected_ann_info)
        # use all annotations
        else:
            results = dict(img_info=copy.deepcopy(img_info),
                           ann_info=ann_info)
        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]
        self.pre_pipeline(results)
        if pipeline_key is None:
            return self.pipeline(results)
        else:
            return self.multi_pipelines[pipeline_key](results)

    def _filter_annotations(self, data_infos: List[Dict],
                            ann_shot_filter: Dict) -> List[Dict]:
        """Filter out extra annotations of specific class, while annotations
        of classes not in filter remain unchanged and the ignored annotations
        will be removed.

        Args:
            data_infos (list[dict]): Annotation infos.
            ann_shot_filter (dict): Specific which class and how many
                instances of each class to load from annotation file.
                For example: {'dog': 10, 'cat': 10, 'person': 5}

        Returns:
            list[dict]: Annotation infos where number of specified class
                shots less than or equal to predefined number.
        """
        if ann_shot_filter is None:
            return data_infos
        # build instance indices of (img_id, gt_idx)
        filter_instances = {key: [] for key in ann_shot_filter.keys()}
        keep_instances_indices = []
        for idx, data_info in enumerate(data_infos):
            ann = data_info['ann']
            for i in range(ann['labels'].shape[0]):
                instance_class_name = self.CLASSES[ann['labels'][i]]
                # only filter instance from the filter class
                if instance_class_name in ann_shot_filter.keys():
                    filter_instances[instance_class_name].append((idx, i))
                # skip the class not in the filter
                else:
                    keep_instances_indices.append((idx, i))
        # filter extra shots
        for class_name in ann_shot_filter.keys():
            num_shots = ann_shot_filter[class_name]
            instance_indices = filter_instances[class_name]
            # zero shots: drop every instance of this class
            if num_shots == 0:
                continue
            # random sample from all instances
            if len(instance_indices) > num_shots:
                random_select = np.random.choice(
                    len(instance_indices), num_shots, replace=False)
                keep_instances_indices += \
                    [instance_indices[i] for i in random_select]
            # number of available shots less than the predefined number,
            # which may cause the performance degradation
            else:
                # check the number of instance
                if len(instance_indices) < num_shots:
                    warnings.warn(f'number of {class_name} instance is '
                                  f'{len(instance_indices)} which is '
                                  f'less than predefined shots {num_shots}.')
                keep_instances_indices += instance_indices
        # group kept gt indices by image once, instead of rescanning the
        # whole keep list for every image (avoids quadratic behavior)
        keep_per_img: Dict[int, List[int]] = {}
        for img_idx, gt_idx in keep_instances_indices:
            keep_per_img.setdefault(img_idx, []).append(gt_idx)
        # keep the selected annotations and remove the undesired annotations
        new_data_infos = []
        for idx, data_info in enumerate(data_infos):
            selected_instance_indices = sorted(keep_per_img.get(idx, []))
            if len(selected_instance_indices) == 0:
                continue
            ann = data_info['ann']
            selected_ann = dict(
                bboxes=ann['bboxes'][selected_instance_indices],
                labels=ann['labels'][selected_instance_indices],
            )
            new_data_infos.append(
                dict(
                    id=data_info['id'],
                    filename=data_info['filename'],
                    width=data_info['width'],
                    height=data_info['height'],
                    ann=selected_ann))
        return new_data_infos

    def _filter_bboxs(self, min_bbox_size: Union[int, float]) -> List[Dict]:
        """Move bboxes smaller than ``min_bbox_size`` to the ignored field.

        Boxes with width or height below the threshold are appended to
        ``bboxes_ignore``/``labels_ignore`` rather than deleted.

        Args:
            min_bbox_size (int | float): Minimum width/height of a bbox.

        Returns:
            list[dict]: Updated annotation infos (entries are modified
                in place).
        """
        new_data_infos = []
        for data_info in self.data_infos:
            ann = data_info['ann']
            keep_idx, ignore_idx = [], []
            for i in range(ann['bboxes'].shape[0]):
                bbox = ann['bboxes'][i]
                w = bbox[2] - bbox[0]
                h = bbox[3] - bbox[1]
                # check bbox size
                if w < min_bbox_size or h < min_bbox_size:
                    ignore_idx.append(i)
                else:
                    keep_idx.append(i)
            # remove undesired bbox
            if len(ignore_idx) > 0:
                bboxes_ignore = ann.get('bboxes_ignore', np.zeros((0, 4)))
                labels_ignore = ann.get('labels_ignore', np.zeros((0, )))
                new_bboxes_ignore = ann['bboxes'][ignore_idx]
                new_labels_ignore = ann['labels'][ignore_idx]
                bboxes_ignore = np.concatenate(
                    (bboxes_ignore, new_bboxes_ignore))
                labels_ignore = np.concatenate(
                    (labels_ignore, new_labels_ignore))
                data_info.update(
                    ann=dict(
                        bboxes=ann['bboxes'][keep_idx],
                        labels=ann['labels'][keep_idx],
                        bboxes_ignore=bboxes_ignore,
                        labels_ignore=labels_ignore))
            new_data_infos.append(data_info)
        return new_data_infos

    def _set_group_flag(self) -> None:
        """Set flag according to image aspect ratio.

        Images with aspect ratio greater than 1 will be set as group 1,
        otherwise group 0. In few shot setting, the limited number of images
        might cause some mini-batch always sample a certain number of images
        and thus not fully shuffle the data, which may degrade the
        performance. Therefore, all flags are simply set to 0.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)

    def load_annotations_saved(self, ann_file: str) -> List[Dict]:
        """Load data_infos from saved json."""
        with open(ann_file) as f:
            data_infos = json.load(f)
        # record the index of meta info
        meta_idx = None
        for i, data_info in enumerate(data_infos):
            # check the meta info CLASSES and img_prefix saved in json
            if 'CLASSES' in data_info.keys():
                assert self.CLASSES == tuple(data_info['CLASSES']), \
                    f'{self.dataset_name} : class labels mismatch.'
                assert self.img_prefix == data_info['img_prefix'], \
                    f'{self.dataset_name} : image prefix mismatch.'
                meta_idx = i
                # skip the meta info
                continue
            # convert annotations from list into numpy array
            for k in data_info['ann']:
                assert isinstance(data_info['ann'][k], list)
                # load bboxes and bboxes_ignore
                if 'bboxes' in k:
                    # bboxes_ignore can be empty
                    if len(data_info['ann'][k]) == 0:
                        data_info['ann'][k] = np.zeros((0, 4))
                    else:
                        data_info['ann'][k] = \
                            np.array(data_info['ann'][k], dtype=np.float32)
                # load labels and labels_ignore
                elif 'labels' in k:
                    # labels_ignore can be empty
                    if len(data_info['ann'][k]) == 0:
                        data_info['ann'][k] = np.zeros((0, ))
                    else:
                        data_info['ann'][k] = \
                            np.array(data_info['ann'][k], dtype=np.int64)
                else:
                    raise KeyError(f'unsupported key {k} in ann field')
        # remove meta info
        if meta_idx is not None:
            data_infos.pop(meta_idx)
        return data_infos

    def save_data_infos(self, output_path: str) -> None:
        """Save data_infos into json."""
        # numpy array will be saved as list in the json
        meta_info = [{'CLASSES': self.CLASSES, 'img_prefix': self.img_prefix}]
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(
                meta_info + self.data_infos,
                f,
                ensure_ascii=False,
                indent=4,
                cls=NumpyEncoder)

    def __repr__(self) -> str:
        """Print the number of instances of each class."""
        result = (f'\n{self.__class__.__name__} {self.dataset_name} '
                  f'with number of images {len(self)}, '
                  f'and instance counts: \n')
        if self.CLASSES is None:
            result += 'Category names are not provided. \n'
            return result
        # last slot counts background-only (empty-label) images
        instance_count = np.zeros(len(self.CLASSES) + 1).astype(int)
        # count the instance number in each image
        for idx in range(len(self)):
            label = self.get_ann_info(idx)['labels']
            unique, counts = np.unique(label, return_counts=True)
            if len(unique) > 0:
                # add the occurrence number to each class
                instance_count[unique] += counts
            else:
                # background is the last index
                instance_count[-1] += 1
        # create a table with category count, five (category, count)
        # column pairs per row
        table_data = [['category', 'count'] * 5]
        row_data = []
        for cls, count in enumerate(instance_count):
            if cls < len(self.CLASSES):
                row_data += [f'{cls} [{self.CLASSES[cls]}]', f'{count}']
            else:
                # add the background number
                row_data += ['-1 background', f'{count}']
            if len(row_data) == 10:
                table_data.append(row_data)
                row_data = []
        if len(row_data) != 0:
            table_data.append(row_data)
        table = AsciiTable(table_data)
        result += table.table
        return result