# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import numpy as np
from xtcocotools.coco import COCO

from easycv.datasets.registry import DATASOURCES, PIPELINES
from easycv.datasets.shared.pipelines import Compose
from easycv.framework.errors import TypeError
from easycv.utils.registry import build_from_cfg


@DATASOURCES.register_module
class DetSourceCoco(object):
    """Detection data source backed by a COCO-style annotation file.

    Image records and annotations are read through the COCO API, optionally
    filtered, and every sample is passed through a transform pipeline.
    """

    def __init__(self,
                 ann_file,
                 img_prefix,
                 pipeline,
                 test_mode=False,
                 filter_empty_gt=False,
                 classes=None,
                 iscrowd=False):
        """
        Args:
            ann_file: Path of annotation file.
            img_prefix: coco path prefix
            pipeline (list): Transforms applied to every sample. Each item
                is either a config dict (built from ``PIPELINES``) or a
                callable.
            test_mode (bool, optional): If set True, `self._filter_imgs`
                will not be applied and no group flag is computed.
            filter_empty_gt (bool, optional): If set true, images without
                bounding boxes of the dataset's classes will be filtered
                out. This option only works when `test_mode=False`, i.e.,
                we never filter images during tests.
            classes (list[str], optional): Names of the categories to keep.
                ``None`` keeps every category of the annotation file.
            iscrowd (bool): Whether crowd annotations are kept. Set it to
                False for training and to True for validation.
        """
        self.ann_file = ann_file
        self.img_prefix = img_prefix
        self.filter_empty_gt = filter_empty_gt
        self.CLASSES = classes

        # load annotations (and proposals)
        self.data_infos = self.load_annotations(self.ann_file)

        self.test_mode = test_mode
        if not test_mode:
            valid_inds = self._filter_imgs()
            self.data_infos = [self.data_infos[i] for i in valid_inds]
            # the aspect-ratio group flag is only computed for training
            self._set_group_flag()

        self.iscrowd = iscrowd
        self.max_labels_num = 120

        transforms = []
        for transform in pipeline:
            if isinstance(transform, dict):
                transforms.append(build_from_cfg(transform, PIPELINES))
            elif callable(transform):
                transforms.append(transform)
            else:
                raise TypeError('transform must be callable or a dict')
        self.pipeline = Compose(transforms)

    def __len__(self):
        """Total number of samples of data."""
        return len(self.data_infos)

    def load_annotations(self, ann_file):
        """Load annotation from COCO style annotation file.

        Besides returning the image records, this sets ``self.coco``,
        ``self.cat_ids``, ``self.cat2label`` and ``self.img_ids``.

        Args:
            ann_file (str): Path of annotation file.

        Returns:
            list[dict]: Annotation info from COCO api.
        """
        self.coco = COCO(ann_file)

        # The order of returned `cat_ids` will not
        # change with the order of the CLASSES
        if self.CLASSES is None:
            # NOTE(review): a `None` name filter appears to be wrapped into
            # `[None]` by the COCO API and would then match no category at
            # all, so query without a filter to keep every category.
            self.cat_ids = self.coco.getCatIds()
        else:
            self.cat_ids = self.coco.getCatIds(catNms=self.CLASSES)

        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.getImgIds()

        data_infos = []
        total_ann_ids = []
        for img_id in self.img_ids:
            info = self.coco.loadImgs([img_id])[0]
            info['filename'] = info['file_name']
            data_infos.append(info)
            total_ann_ids.extend(self.coco.getAnnIds(imgIds=[img_id]))

        assert len(set(total_ann_ids)) == len(
            total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
        return data_infos

    def _load_raw_anns(self, idx):
        """Return the raw COCO annotation dicts of the image at ``idx``."""
        img_id = self.data_infos[idx]['id']
        ann_ids = self.coco.getAnnIds(imgIds=[img_id])
        return self.coco.loadAnns(ann_ids)

    def get_ann_info(self, idx):
        """Get COCO annotation by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """
        return self._parse_ann_info(self.data_infos[idx],
                                    self._load_raw_anns(idx))

    def get_cat_ids(self, idx):
        """Get COCO category ids by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """
        return [ann['category_id'] for ann in self._load_raw_anns(idx)]

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without ground truths.

        ``self.img_ids`` is replaced by the ids of the images that are kept.

        Args:
            min_size (int): Minimum length of the shorter image side.

        Returns:
            list[int]: Indices into ``self.data_infos`` of the kept images.
        """
        # obtain images that contain annotation
        ids_with_ann = set(ann['image_id'] for ann in self.coco.anns.values())

        # obtain images that contain annotations of the required categories
        ids_in_cat = set()
        for class_id in self.cat_ids:
            ids_in_cat |= set(self.coco.catToImgs[class_id])

        # merge the image id sets of the two conditions and use the merged set
        # to filter out images if self.filter_empty_gt=True
        ids_in_cat &= ids_with_ann

        valid_inds = []
        valid_img_ids = []
        for i, img_info in enumerate(self.data_infos):
            img_id = self.img_ids[i]
            if self.filter_empty_gt and img_id not in ids_in_cat:
                continue
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
                valid_img_ids.append(img_id)

        self.img_ids = valid_img_ids
        return valid_inds

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        Images with aspect ratio greater than 1 will be set as group 1,
        otherwise group 0.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)
        for i, img_info in enumerate(self.data_infos):
            if img_info['width'] / img_info['height'] > 1:
                self.flag[i] = 1

    def _parse_ann_info(self, img_info, ann_info):
        """Parse bbox and mask annotation.

        Annotations are skipped when they are marked ``ignore``, when they
        are crowd annotations and ``self.iscrowd`` is False, when the box
        does not overlap the image, when the area is not positive or a side
        is shorter than one pixel, or when the category is not selected.

        Args:
            img_info (dict): Image record; ``width``, ``height`` and
                ``filename`` are read.
            ann_info (list[dict]): Annotation info of an image.

        Returns:
            dict: A dict containing the following keys: bboxes, labels,
                groundtruth_is_crowd, bboxes_ignore, masks, seg_map.
                "masks" are raw annotations and not decoded into binary
                masks.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        gt_masks_ann = []
        groundtruth_is_crowd = []

        for ann in ann_info:
            if ann.get('ignore', False):
                continue
            is_crowd = ann.get('iscrowd', False)
            if is_crowd and not self.iscrowd:
                # while training, skip iscrowd
                continue

            x1, y1, w, h = ann['bbox']
            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            # dict lookup instead of scanning the `cat_ids` list; both hold
            # exactly the same category ids
            if ann['category_id'] not in self.cat2label:
                continue

            bbox = [x1, y1, x1 + w, y1 + h]
            gt_bboxes.append(bbox)
            gt_labels.append(self.cat2label[ann['category_id']])
            gt_masks_ann.append(ann.get('segmentation', None))
            if is_crowd:
                # crowded gt bboxes are kept for eval (not needed in
                # training) and additionally listed as ignore regions
                gt_bboxes_ignore.append(bbox)
                groundtruth_is_crowd.append(1)
            else:
                groundtruth_is_crowd.append(0)

        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
            groundtruth_is_crowd = np.array(
                groundtruth_is_crowd, dtype=np.int8)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)
            groundtruth_is_crowd = np.array([], dtype=np.int8)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        # Swap only the file extension: a plain `replace('jpg', 'png')`
        # would also rewrite 'jpg' inside directory or base names.
        seg_map = osp.splitext(img_info['filename'])[0] + '.png'

        ann = dict(
            bboxes=gt_bboxes,
            labels=gt_labels,
            groundtruth_is_crowd=groundtruth_is_crowd,
            bboxes_ignore=gt_bboxes_ignore,
            masks=gt_masks_ann,
            seg_map=seg_map)

        return ann

    def xyxy2xywh(self, bbox):
        """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO
        evaluation.

        Args:
            bbox (numpy.ndarray): The bounding boxes, shape (4, ), in
                ``xyxy`` order. Trailing elements (e.g. a score) are
                ignored.

        Returns:
            list[float]: The converted bounding boxes, in ``xywh`` order.
        """
        _bbox = bbox.tolist()
        return [
            _bbox[0],
            _bbox[1],
            _bbox[2] - _bbox[0],
            _bbox[3] - _bbox[1],
        ]

    def _proposal2json(self, results):
        """Convert proposal results to COCO json style.

        ``results[idx]`` holds the proposals of sample ``idx``; each row is
        read as four ``xyxy`` coordinates followed by a score.
        """
        json_results = []
        for idx in range(len(self)):
            img_id = self.img_ids[idx]
            for box in results[idx]:
                json_results.append(
                    dict(
                        image_id=img_id,
                        bbox=self.xyxy2xywh(box),
                        score=float(box[4]),
                        category_id=1))
        return json_results

    def _det2json(self, results):
        """Convert detection results to COCO json style.

        ``results[idx][label]`` holds the detections of class ``label`` for
        sample ``idx``; each row is read as four ``xyxy`` coordinates
        followed by a score.
        """
        json_results = []
        for idx in range(len(self)):
            img_id = self.img_ids[idx]
            for label, bboxes in enumerate(results[idx]):
                for box in bboxes:
                    json_results.append(
                        dict(
                            image_id=img_id,
                            bbox=self.xyxy2xywh(box),
                            score=float(box[4]),
                            category_id=self.cat_ids[label]))
        return json_results

    def pre_pipeline(self, results):
        """Prepare results dict for pipeline."""
        results['img_prefix'] = self.img_prefix
        results['bbox_fields'] = []
        results['mask_fields'] = []
        results['seg_fields'] = []

    def prepare_train_img(self, idx):
        """Get training data and annotations after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys
                introduced by pipeline.
        """
        img_info = self.data_infos[idx]
        ann_info = self.get_ann_info(idx)
        results = dict(img_info=img_info, ann_info=ann_info)
        self.pre_pipeline(results)
        return self.pipeline(results)

    def __getitem__(self, idx):
        """Get the sample at ``idx`` after the pipeline.

        When the pipeline returns ``None`` for a sample, another index is
        drawn at random and tried instead, until a sample is produced.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Data and annotation after pipeline.
        """
        while True:
            data = self.prepare_train_img(idx)
            if data is not None:
                return data
            idx = self._rand_another(idx)

    def _rand_another(self, idx):
        """Get another random index from the same group as the given index.

        The group flag only exists when the source was built with
        ``test_mode=False``; without it any index may be drawn.
        """
        flag = getattr(self, 'flag', None)
        if flag is None:
            return np.random.choice(len(self))
        pool = np.where(flag == flag[idx])[0]
        return np.random.choice(pool)