# EasyCV/easycv/datasets/detection/data_sources/coco.py

# Copyright (c) OpenMMLab. All rights reserved.
import os
import numpy as np
from xtcocotools.coco import COCO
from easycv.datasets.registry import DATASOURCES, PIPELINES
from easycv.datasets.shared.pipelines import Compose
from easycv.datasets.utils.download_data.download_coco import (
check_data_exists, download_coco)
from easycv.framework.errors import TypeError
from easycv.utils.registry import build_from_cfg
@DATASOURCES.register_module
class DetSourceCoco(object):
"""
    COCO data source.
"""
def __init__(self,
ann_file,
img_prefix,
pipeline,
test_mode=False,
filter_empty_gt=False,
classes=None,
iscrowd=False):
"""
Args:
            ann_file (str): Path of the annotation file.
            img_prefix (str): Prefix of the image paths.
            pipeline (list): Processing pipeline; each element is either a
                transform config dict or a callable transform.
            test_mode (bool, optional): If set True, `self._filter_imgs` will not work.
            filter_empty_gt (bool, optional): If set true, images without bounding
                boxes of the dataset's classes will be filtered out. This option
                only works when `test_mode=False`, i.e., we never filter images
                during tests.
            classes (list[str], optional): Names of the categories to load.
            iscrowd (bool): Whether to keep crowd annotations. Typically set to
                False for training and True for evaluation.
"""
self.ann_file = ann_file
self.img_prefix = img_prefix
self.filter_empty_gt = filter_empty_gt
self.CLASSES = classes
# load annotations (and proposals)
self.data_infos = self.load_annotations(self.ann_file)
self.test_mode = test_mode
if not test_mode:
valid_inds = self._filter_imgs()
self.data_infos = [self.data_infos[i] for i in valid_inds]
self._set_group_flag()
self.iscrowd = iscrowd
self.max_labels_num = 120
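        # Each pipeline entry may be either an already-built callable or a
        # registry config dict with a 'type' key, e.g. (illustrative only;
        # available transform names depend on your EasyCV version):
        #   dict(type='MMResize', img_scale=(640, 640), keep_ratio=True)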
transforms = []
for transform in pipeline:
if isinstance(transform, dict):
transform = build_from_cfg(transform, PIPELINES)
transforms.append(transform)
elif callable(transform):
transforms.append(transform)
else:
raise TypeError('transform must be callable or a dict')
self.pipeline = Compose(transforms)
def __len__(self):
"""Total number of samples of data."""
return len(self.data_infos)
def load_annotations(self, ann_file):
"""Load annotation from COCO style annotation file.
Args:
ann_file (str): Path of annotation file.
Returns:
list[dict]: Annotation info from COCO api.
"""
self.coco = COCO(ann_file)
        # The order of the returned `cat_ids` follows the category order in the
        # annotation file; it does not change with the order of `self.CLASSES`.
self.cat_ids = self.coco.getCatIds(catNms=self.CLASSES)
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
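        # COCO category ids are not necessarily contiguous (e.g. the 80 COCO
        # classes use ids 1-90), so `cat2label` remaps them to contiguous
        # training labels 0..num_classes-1.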
self.img_ids = self.coco.getImgIds()
data_infos = []
total_ann_ids = []
for i in self.img_ids:
info = self.coco.loadImgs([i])[0]
info['filename'] = info['file_name']
data_infos.append(info)
ann_ids = self.coco.getAnnIds(imgIds=[i])
total_ann_ids.extend(ann_ids)
assert len(set(total_ann_ids)) == len(
total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
return data_infos
def get_ann_info(self, idx):
"""Get COCO annotation by index.
Args:
idx (int): Index of data.
Returns:
dict: Annotation info of specified index.
"""
img_id = self.data_infos[idx]['id']
ann_ids = self.coco.getAnnIds(imgIds=[img_id])
ann_info = self.coco.loadAnns(ann_ids)
return self._parse_ann_info(self.data_infos[idx], ann_info)
def get_cat_ids(self, idx):
"""Get COCO category ids by index.
Args:
idx (int): Index of data.
Returns:
list[int]: All categories in the image of specified index.
"""
img_id = self.data_infos[idx]['id']
ann_ids = self.coco.getAnnIds(imgIds=[img_id])
ann_info = self.coco.loadAnns(ann_ids)
return [ann['category_id'] for ann in ann_info]
def _filter_imgs(self, min_size=32):
"""Filter images too small or without ground truths."""
valid_inds = []
# obtain images that contain annotation
ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
# obtain images that contain annotations of the required categories
ids_in_cat = set()
for i, class_id in enumerate(self.cat_ids):
ids_in_cat |= set(self.coco.catToImgs[class_id])
# merge the image id sets of the two conditions and use the merged set
# to filter out images if self.filter_empty_gt=True
ids_in_cat &= ids_with_ann
valid_img_ids = []
for i, img_info in enumerate(self.data_infos):
img_id = self.img_ids[i]
if self.filter_empty_gt and img_id not in ids_in_cat:
continue
if min(img_info['width'], img_info['height']) >= min_size:
valid_inds.append(i)
valid_img_ids.append(img_id)
self.img_ids = valid_img_ids
return valid_inds
def _set_group_flag(self):
"""Set flag according to image aspect ratio.
Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0.
"""
self.flag = np.zeros(len(self), dtype=np.uint8)
for i in range(len(self)):
img_info = self.data_infos[i]
if img_info['width'] / img_info['height'] > 1:
self.flag[i] = 1
def _parse_ann_info(self, img_info, ann_info):
"""Parse bbox and mask annotation.
        Args:
            img_info (dict): Information of the image (e.g. width, height, filename).
            ann_info (list[dict]): Annotation info of an image.
Returns:
dict: A dict containing the following keys: bboxes, bboxes_ignore,\
labels, masks, seg_map. "masks" are raw annotations and not \
decoded into binary masks.
"""
gt_bboxes = []
gt_labels = []
gt_bboxes_ignore = []
gt_masks_ann = []
groundtruth_is_crowd = []
for i, ann in enumerate(ann_info):
if ann.get('ignore', False):
continue
if ann.get('iscrowd', False) and (
not self.iscrowd): # while training, skip iscrowd
continue
x1, y1, w, h = ann['bbox']
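            # Intersect the box with the image canvas; boxes that fall entirely
            # outside the image contribute no visible area and are skipped.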
inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
if inter_w * inter_h == 0:
continue
if ann['area'] <= 0 or w < 1 or h < 1:
continue
if ann['category_id'] not in self.cat_ids:
continue
bbox = [x1, y1, x1 + w, y1 + h]
            if ann.get('iscrowd', False):
                # Crowd boxes reach this branch only when `self.iscrowd=True`
                # (evaluation); during training they are skipped above.
                gt_bboxes.append(bbox)
gt_labels.append(self.cat2label[ann['category_id']])
gt_masks_ann.append(ann.get('segmentation', None))
gt_bboxes_ignore.append(bbox)
groundtruth_is_crowd.append(1)
else:
gt_bboxes.append(bbox)
gt_labels.append(self.cat2label[ann['category_id']])
gt_masks_ann.append(ann.get('segmentation', None))
groundtruth_is_crowd.append(0)
if gt_bboxes:
gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
gt_labels = np.array(gt_labels, dtype=np.int64)
groundtruth_is_crowd = np.array(
groundtruth_is_crowd, dtype=np.int8)
else:
gt_bboxes = np.zeros((0, 4), dtype=np.float32)
gt_labels = np.array([], dtype=np.int64)
groundtruth_is_crowd = np.array([], dtype=np.int8)
if gt_bboxes_ignore:
gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
else:
gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
seg_map = img_info['filename'].replace('jpg', 'png')
ann = dict(
bboxes=gt_bboxes,
labels=gt_labels,
groundtruth_is_crowd=groundtruth_is_crowd,
bboxes_ignore=gt_bboxes_ignore,
masks=gt_masks_ann,
seg_map=seg_map)
return ann
def xyxy2xywh(self, bbox):
"""Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO
evaluation.
Args:
bbox (numpy.ndarray): The bounding boxes, shape (4, ), in
``xyxy`` order.
Returns:
list[float]: The converted bounding boxes, in ``xywh`` order.
"""
_bbox = bbox.tolist()
return [
_bbox[0],
_bbox[1],
_bbox[2] - _bbox[0],
_bbox[3] - _bbox[1],
]
def _proposal2json(self, results):
"""Convert proposal results to COCO json style."""
json_results = []
for idx in range(len(self)):
img_id = self.img_ids[idx]
bboxes = results[idx]
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = self.xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = 1
json_results.append(data)
return json_results
def _det2json(self, results):
"""Convert detection results to COCO json style."""
json_results = []
for idx in range(len(self)):
img_id = self.img_ids[idx]
result = results[idx]
for label in range(len(result)):
bboxes = result[label]
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = self.xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = self.cat_ids[label]
json_results.append(data)
return json_results
def pre_pipeline(self, results):
"""Prepare results dict for pipeline."""
results['img_prefix'] = self.img_prefix
results['bbox_fields'] = []
results['mask_fields'] = []
results['seg_fields'] = []
def prepare_train_img(self, idx):
"""Get training data and annotations after pipeline.
Args:
idx (int): Index of data.
Returns:
dict: Training data and annotation after pipeline with new keys \
introduced by pipeline.
"""
img_info = self.data_infos[idx]
ann_info = self.get_ann_info(idx)
results = dict(img_info=img_info, ann_info=ann_info)
self.pre_pipeline(results)
return self.pipeline(results)
def __getitem__(self, idx):
"""Get training/test data after pipeline.
Args:
idx (int): Index of data.
Returns:
dict: Training/test data (with annotation if `test_mode` is set \
True).
"""
while True:
data = self.prepare_train_img(idx)
if data is None:
idx = self._rand_another(idx)
continue
return data
def _rand_another(self, idx):
"""Get another random index from the same group as the given index."""
pool = np.where(self.flag == self.flag[idx])[0]
return np.random.choice(pool)
@DATASOURCES.register_module
class DetSourceCoco2017(DetSourceCoco):
"""
    COCO 2017 data source.
"""
def __init__(self,
pipeline,
path=None,
download=True,
split='train',
test_mode=False,
filter_empty_gt=False,
classes=None,
iscrowd=False):
"""
Args:
            path (str, optional): Root directory of the data. If `download=True`
                and `path` is not provided, a temporary directory is created
                automatically for the download.
            download (bool): If True, the dataset is downloaded into `path`
                automatically. If False, nothing is downloaded and the data
                already present in `path` is used.
            split (str): Dataset split, 'train' or 'val'.
            test_mode (bool, optional): If set True, `self._filter_imgs` will not work.
            filter_empty_gt (bool, optional): If set true, images without bounding
                boxes of the dataset's classes will be filtered out. This option
                only works when `test_mode=False`, i.e., we never filter images
                during tests.
            classes (list[str], optional): Names of the categories to load.
            iscrowd (bool): Whether to keep crowd annotations. Typically set to
                False for training and True for evaluation.
"""
if download:
if path:
assert os.path.isdir(path), f'{path} is not dir'
path = download_coco(
'coco2017', split=split, target_dir=path, task='detection')
else:
path = download_coco('coco2017', split=split, task='detection')
else:
if path:
assert os.path.isdir(path), f'{path} is not dir'
path = check_data_exists(
target_dir=path, split=split, task='detection')
else:
                raise KeyError('`path` must not be None when `download=False`')
super(DetSourceCoco2017, self).__init__(
ann_file=path['ann_file'],
img_prefix=path['img_prefix'],
pipeline=pipeline,
test_mode=test_mode,
filter_empty_gt=filter_empty_gt,
classes=classes,
iscrowd=iscrowd)
@DATASOURCES.register_module
class DetSourceTinyPerson(DetSourceCoco):
"""
    TinyPerson data source.
"""
CLASSES = ['sea_person', 'earth_person']
def __init__(self,
ann_file,
img_prefix,
pipeline,
test_mode=False,
filter_empty_gt=False,
classes=CLASSES,
iscrowd=False):
"""
Args:
            ann_file (str): Path of the annotation file.
            img_prefix (str): Prefix of the image paths.
            test_mode (bool, optional): If set True, `self._filter_imgs` will not work.
            filter_empty_gt (bool, optional): If set true, images without bounding
                boxes of the dataset's classes will be filtered out. This option
                only works when `test_mode=False`, i.e., we never filter images
                during tests.
            classes (list[str], optional): Names of the categories to load.
                Defaults to `DetSourceTinyPerson.CLASSES`.
            iscrowd (bool): Whether to keep crowd annotations. Typically set to
                False for training and True for evaluation.
"""
super(DetSourceTinyPerson, self).__init__(
ann_file=ann_file,
img_prefix=img_prefix,
pipeline=pipeline,
test_mode=test_mode,
filter_empty_gt=filter_empty_gt,
classes=classes,
iscrowd=iscrowd)
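

if __name__ == '__main__':
    # Minimal usage sketch of DetSourceCoco (the paths and class names below
    # are placeholders and must be adapted to your local setup).
    # An empty pipeline simply returns the dict produced by `pre_pipeline` and
    # `get_ann_info`, which is enough for a quick sanity check of the labels.
    source = DetSourceCoco(
        ann_file='data/coco/annotations/instances_train2017.json',
        img_prefix='data/coco/train2017/',
        pipeline=[],
        classes=['person', 'car'],
        iscrowd=False)
    print('num images:', len(source))
    sample = source[0]
    print('bboxes:', sample['ann_info']['bboxes'].shape)
    print('labels:', sample['ann_info']['labels'])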