mirror of https://github.com/alibaba/EasyCV.git
204 lines
8.0 KiB
Python
204 lines
8.0 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import os
|
|
|
|
import numpy as np
|
|
from pycocotools.coco import COCO
|
|
from tqdm import tqdm
|
|
|
|
from easycv.datasets.registry import DATASOURCES
|
|
from easycv.datasets.utils.download_data.download_coco import (
|
|
check_data_exists, download_coco)
|
|
from easycv.utils.constant import CACHE_DIR
|
|
from .base import load_image
|
|
|
|
|
|
@DATASOURCES.register_module
class SegSourceCoco(object):
    """Semantic-segmentation data source backed by a COCO-format annotation file.

    The instance annotations of the selected categories are rasterised into
    one label map per image. A pixel covered by an annotation of the ``k``-th
    selected category (0-based, in the order returned by ``COCO.getCatIds``)
    gets value ``k + 1``. Uncovered pixels stay ``0``. Where annotations
    overlap, the one listed later by ``COCO.getAnnIds`` wins. Only images that
    have at least one selected annotation are exposed.
    """

    COCO_CLASSES = [
        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
        'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
        'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
        'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
        'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
        'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
        'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
        'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
        'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
        'scissors', 'teddy bear', 'hair drier', 'toothbrush'
    ]

    def __init__(self,
                 ann_file,
                 img_prefix,
                 palette=None,
                 reduce_zero_label=False,
                 classes=COCO_CLASSES,
                 iscrowd=False) -> None:
        """
        Args:
            ann_file: Path of the COCO-format annotation file.
            img_prefix: Directory prefix joined with each image's
                ``file_name``.
            palette (Sequence[Sequence[int]] | np.ndarray | None): palette of
                the segmentation map; if None, a fixed-seed random palette is
                generated.
            reduce_zero_label (bool): whether to mark label zero as ignored
                (0 becomes 255 and every other label is shifted down by one).
            classes (str | list | None): category names to keep, or the path
                of a text file with one category name per line. A string that
                is not an existing file is treated as a single category name;
                None means ``COCO_CLASSES``.
            iscrowd: forwarded to ``COCO.getAnnIds(iscrowd=...)``, which keeps
                only annotations whose ``iscrowd`` field equals this value
                (False keeps non-crowd annotations, True keeps crowd ones).
                Historically documented as False for training and True for
                validation.
                NOTE(review): True selects *only* crowd annotations; confirm
                that is really intended for validation.
        """
        self.ann_file = ann_file
        self.img_prefix = img_prefix
        self.iscrowd = iscrowd
        self.reduce_zero_label = reduce_zero_label
        if palette is not None:
            self.PALETTE = palette
        else:
            self.PALETTE = self.get_random_palette()

        self.seg = COCO(self.ann_file)
        self.catIds = self.seg.getCatIds(catNms=self._parse_classes(classes))
        self.imgIds = self._load_annotations(self.seg.getImgIds())

    @property
    def img_prefile(self):
        """Deprecated misspelled alias of ``img_prefix`` (kept for compatibility)."""
        return self.img_prefix

    @img_prefile.setter
    def img_prefile(self, value):
        self.img_prefix = value

    def _parse_classes(self, classes):
        """Normalise ``classes`` into a list of category names for ``getCatIds``.

        If ``classes`` is a string naming an existing file, the file is read as
        one category name per non-blank line. Any other string is a single
        category name. None selects ``COCO_CLASSES``. Other iterables are
        copied into a list.
        """
        if classes is None:
            return list(self.COCO_CLASSES)
        if isinstance(classes, str):
            # pycocotools would otherwise iterate the string character by
            # character and match no category at all.
            if os.path.isfile(classes):
                with open(classes, encoding='utf-8') as f:
                    return [line.strip() for line in f if line.strip()]
            return [classes]
        return list(classes)

    def _load_annotations(self, imgIds):
        """Return the ids from ``imgIds`` that have at least one selected annotation.

        The input order is preserved.
        """
        if not self.catIds:
            # With an empty catIds list, COCO.getAnnIds skips category
            # filtering and returns annotations of every category. Those could
            # not be mapped to labels later, so no image is selected.
            return []
        seg_imgIds = []
        for imgId in tqdm(imgIds, desc='Scanning images'):
            annIds = self.seg.getAnnIds(
                imgIds=imgId, catIds=self.catIds, iscrowd=self.iscrowd)
            if annIds:
                seg_imgIds.append(imgId)
        return seg_imgIds

    def load_seg_map(self, gt_semantic_seg, reduce_zero_label):
        """Optionally apply the ``reduce_zero_label`` remapping to a label map.

        When ``reduce_zero_label`` is True, label 0 becomes the ignore value
        255, label 255 stays 255, and every other label ``k`` becomes
        ``k - 1``. The input array is modified in place before the shift.

        Returns:
            The (possibly remapped) label map.
        """
        if reduce_zero_label:
            # Map 0 -> 255 first so the subtraction cannot underflow an
            # unsigned dtype; the resulting 254 is then restored to 255.
            gt_semantic_seg[gt_semantic_seg == 0] = 255
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == 254] = 255

        return gt_semantic_seg

    def _parse_load_seg(self, ids):
        """Rasterise the selected annotations of image ``ids`` into a label map.

        Each annotation's binary mask (``COCO.annToMask``) is painted with
        ``catIds.index(category_id) + 1``. Later annotations overwrite earlier
        ones where they overlap. The result is passed through
        ``load_seg_map``.
        """
        annIds = self.seg.getAnnIds(
            imgIds=ids, catIds=self.catIds, iscrowd=self.iscrowd)
        anns = self.seg.loadAnns(annIds)

        mask = None
        for ann in anns:
            binary_mask = self.seg.annToMask(ann)
            if mask is None:
                # Same shape/dtype as the masks produced by annToMask.
                mask = np.zeros_like(binary_mask)
            mask[binary_mask > 0] = self.catIds.index(ann['category_id']) + 1

        if mask is None:
            # No selected annotation: return an all-background map of the
            # image size instead of failing with an IndexError.
            img_info = self.seg.loadImgs(ids)[0]
            mask = np.zeros((img_info['height'], img_info['width']),
                            dtype=np.uint8)

        return self.load_seg_map(mask, self.reduce_zero_label)

    def get_random_palette(self):
        """Return a reproducible random palette.

        The palette has shape ``(len(COCO_CLASSES), 3)`` with integers in
        ``[0, 255)``.
        """
        # Get random state before set seed, and restore
        # random state later.
        # It will prevent loss of randomness, as the palette
        # may be different in each iteration if not specified.
        # See: https://github.com/open-mmlab/mmdetection/issues/5844
        # NOTE(review): the palette length follows COCO_CLASSES, not the
        # `classes` actually selected; confirm this matches downstream use.
        state = np.random.get_state()
        np.random.seed(42)
        # random palette
        palette = np.random.randint(0, 255, size=(len(self.COCO_CLASSES), 3))
        np.random.set_state(state)

        return palette

    def __len__(self):
        """Number of images with at least one selected annotation."""
        return len(self.imgIds)

    def __getitem__(self, idx):
        """Build the sample dict for the ``idx``-th selected image.

        The dict holds 'filename', 'gt_semantic_seg', 'img_fields' and
        'seg_fields', updated with whatever ``load_image`` returns
        (presumably the decoded image under 'img'; verify in ``.base``).
        """
        img_info = self.seg.loadImgs(self.imgIds[idx])[0]
        file_name = os.path.join(self.img_prefix, img_info['file_name'])
        gt_semantic_seg = self._parse_load_seg(img_info['id'])
        result = {
            'filename': file_name,
            'gt_semantic_seg': gt_semantic_seg,
            'img_fields': ['img'],
            'seg_fields': ['gt_semantic_seg']
        }
        result.update(load_image(file_name))

        return result
|
|
|
|
|
|
@DATASOURCES.register_module
class SegSourceCoco2017(SegSourceCoco):
    """COCO 2017 segmentation source located on disk or downloaded on demand.

    The ``split`` is resolved with ``download_coco`` or
    ``check_data_exists``. The resulting ``ann_file`` and ``img_prefix`` are
    forwarded to :class:`SegSourceCoco`.
    """

    # Same category names as the parent. Keep an independent copy rather than
    # a second hand-maintained literal.
    COCO_CLASSES = list(SegSourceCoco.COCO_CLASSES)

    def __init__(self,
                 download=False,
                 split='train',
                 path=CACHE_DIR,
                 palette=None,
                 reduce_zero_label=False,
                 classes=COCO_CLASSES,
                 iscrowd=False,
                 **kwargs) -> None:
        """
        Args:
            download: If True, the split is downloaded into ``path``. If
                ``path`` is empty, the download location is left to
                ``download_coco`` (documented as a temporary directory). If
                False, no download happens and the data under ``path`` is
                used.
            split: 'train' or 'val'.
            path: Data root directory (defaults to ``CACHE_DIR``). When
                given, it must be an existing directory.
            palette (Sequence[Sequence[int]] | np.ndarray | None): palette of
                the segmentation map; if None, a random palette is generated.
            reduce_zero_label (bool): whether to mark label zero as ignored.
            classes (str | list): classes list or file.
            iscrowd: forwarded to ``COCO.getAnnIds``. Historically False for
                training and True for validation.
            **kwargs: Accepted for config compatibility and ignored.

        Raises:
            AssertionError: if ``path`` is given but is not a directory.
            KeyError: if ``download`` is False and ``path`` is empty.
        """
        if path:
            assert os.path.isdir(path), f'{path} is not dir'

        if download:
            # Pass target_dir only when a path was given, so download_coco
            # falls back to its own location otherwise.
            target = {'target_dir': path} if path else {}
            path = download_coco(
                'coco2017', split=split, task='detection', **target)
        elif path:
            path = check_data_exists(
                target_dir=path, split=split, task='detection')
        else:
            raise KeyError('your path is None')

        super().__init__(
            path['ann_file'],
            path['img_prefix'],
            palette=palette,
            reduce_zero_label=reduce_zero_label,
            classes=classes,
            iscrowd=iscrowd)
|