# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import logging
import os
from multiprocessing import Pool, cpu_count

import cv2
import mmcv
import numpy as np
from scipy.io import loadmat
from tqdm import tqdm

from easycv.datasets.registry import DATASOURCES
from easycv.file import io
from easycv.file.image import load_image as _load_img

from .base import SegSourceBase
from .raw import parse_raw


@DATASOURCES.register_module
class SegSourceCocoStuff10k(SegSourceBase):
    CLASSES = [
        'unlabeled', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
        'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
        'street sign', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
        'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
        'hat', 'backpack', 'umbrella', 'shoe', 'eye glasses', 'handbag', 'tie',
        'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
        'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
        'tennis racket', 'bottle', 'plate', 'wine glass', 'cup', 'fork',
        'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
        'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window',
        'desk', 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote',
        'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
        'refrigerator', 'blender', 'book', 'clock', 'vase', 'scissors',
        'teddy bear', 'hair drier', 'toothbrush', 'hair brush', 'banner',
        'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',
        'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',
        'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',
        'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',
        'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower',
        'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel',
        'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal',
        'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', 'paper',
        'pavement', 'pillow', 'plant-other', 'plastic', 'platform',
        'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof',
        'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper',
        'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other',
        'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable',
        'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel',
        'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops',
        'window-blind', 'window-other', 'wood'
    ]
    """Data source for COCO-Stuff 10k semantic segmentation.

    data format is as follows:
    ```
    ├── data
    │   ├── images
    │   │   ├── 1.jpg
    │   │   ├── 2.jpg
    │   │   ├── ...
    │   ├── annotations
    │   │   ├── 1.mat
    │   │   ├── 2.mat
    │   │   ├── ...
    │   ├── imageLists
    │   │   ├── train.txt
    │   │   ├── ...
    ```
    Example1:
        data_source = SegSourceCocoStuff10k(
            path='/your/data/imageLists/train.txt',
            label_root='/your/data/annotations',
            img_root='/your/data/images',
            classes=${CLASSES}
        )

    Args:
        path (str): image list file (e.g. imageLists/train.txt), each line
            is a sample name without suffix
        img_root (str): images dir path
        label_root (str): labels dir path
        classes (str | list): classes list or file
        img_suffix (str): image file suffix
        label_suffix (str): label file suffix
        reduce_zero_label (bool): whether to mark label zero as ignored
        palette (Sequence[Sequence[int]] | np.ndarray | None):
            palette of segmentation map, if none, random palette will be generated
        cache_at_init (bool): if set True, will cache in memory in __init__ for faster training
        cache_on_the_fly (bool): if set True, will cache in memory during training
    """
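
    # Minimal usage sketch (paths are placeholders and `demo_source` is only
    # an illustrative name; the sample keys follow `load_image` and
    # `load_seg_map` below):
    #
    #   demo_source = SegSourceCocoStuff10k(
    #       path='/your/data/imageLists/train.txt',
    #       label_root='/your/data/annotations',
    #       img_root='/your/data/images')
    #   sample = demo_source[0]
    #   sample['img']              # float32 BGR array of shape (h, w, 3)
    #   sample['gt_semantic_seg']  # label map loaded from the .mat file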
    def __init__(self,
                 path,
                 img_root=None,
                 label_root=None,
                 classes=CLASSES,
                 img_suffix='.jpg',
                 label_suffix='.mat',
                 reduce_zero_label=False,
                 cache_at_init=False,
                 cache_on_the_fly=False,
                 palette=None,
                 num_processes=int(cpu_count() / 2)):
        if classes is not None:
            self.CLASSES = classes
        if palette is not None:
            self.PALETTE = palette

        self.path = path
        self.img_root = img_root
        self.label_root = label_root

        self.img_suffix = img_suffix
        self.label_suffix = label_suffix

        self.reduce_zero_label = reduce_zero_label
        self.cache_at_init = cache_at_init
        self.cache_on_the_fly = cache_on_the_fly
        self.num_processes = num_processes

        if self.cache_at_init and self.cache_on_the_fly:
            raise ValueError(
                'Only one of `cache_on_the_fly` and `cache_at_init` can be True!'
            )

        assert isinstance(self.CLASSES, (str, tuple, list))
        if isinstance(self.CLASSES, str):
            self.CLASSES = mmcv.list_from_file(classes)
        if self.PALETTE is None:
            self.PALETTE = self.get_random_palette()

        source_iter = self.get_source_iterator()

        self.samples_list = self.build_samples(
            source_iter, process_fn=self.parse_mat)
        self.num_samples = len(self.samples_list)
        # An error will be raised if loading fails _max_retry_num times in a row
        self._max_retry_num = self.num_samples
        self._retry_count = 0

    def parse_mat(self, source_item):
        img_path, seg_path = source_item
        result = {'filename': img_path, 'seg_filename': seg_path}

        if self.cache_at_init:
            result.update(self.load_image(img_path))
            result.update(self.load_seg_map(seg_path, self.reduce_zero_label))

        return result

    def load_seg_map(self, seg_path, reduce_zero_label):
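        # Each COCO-Stuff 10k annotation is a MATLAB .mat file whose 'S' field
        # holds the (h, w) label map; id 0 is the 'unlabeled' class (see
        # CLASSES above). With reduce_zero_label, id 0 is remapped to the
        # ignore index 255 and all remaining ids are shifted down by one.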
        gt_semantic_seg = loadmat(seg_path)['S']
        # reduce zero_label
        if reduce_zero_label:
            # avoid using underflow conversion
            gt_semantic_seg[gt_semantic_seg == 0] = 255
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == 254] = 255

        return {'gt_semantic_seg': gt_semantic_seg}

    def load_image(self, img_path):
        img = _load_img(img_path, mode='BGR')
        result = {
            'img': img.astype(np.float32),
            'img_shape': img.shape,  # h, w, c
            'ori_shape': img.shape,
        }
        return result

    def build_samples(self, iterable, process_fn):
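        # Scan all (image, label) pairs in a worker pool; `process_fn` returns
        # one sample dict per pair and falsy results are dropped.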
        samples_list = []
        with Pool(processes=self.num_processes) as p:
            with tqdm(total=len(iterable), desc='Scanning images') as pbar:
                for _, result_dict in enumerate(
                        p.imap_unordered(process_fn, iterable)):
                    if result_dict:
                        samples_list.append(result_dict)
                    pbar.update()

        return samples_list

    def get_source_iterator(self):
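        # `self.path` is the split file under imageLists/ (e.g. train.txt);
        # each line is a sample name shared by the image and its .mat label.
        # Pairs with a missing image or label file are skipped.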
        with io.open(self.path, 'r') as f:
            lines = f.read().splitlines()

        img_files = []
        label_files = []
        for line in lines:
            img_filename = os.path.join(self.img_root, line + self.img_suffix)
            label_filename = os.path.join(self.label_root,
                                          line + self.label_suffix)

            if os.path.exists(img_filename) and os.path.exists(label_filename):
                img_files.append(img_filename)
                label_files.append(label_filename)

        return list(zip(img_files, label_files))

    def __getitem__(self, idx):
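        # Load (and optionally cache) one sample; if it fails, fall back to
        # the next index and raise only after _max_retry_num consecutive
        # failures.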
        result_dict = self.samples_list[idx]
        load_success = True
        try:
            # avoid the data cache taking up too much memory
            if not self.cache_at_init and not self.cache_on_the_fly:
                result_dict = copy.deepcopy(result_dict)

            if not self.cache_at_init:
                if result_dict.get('img', None) is None:
                    result_dict.update(
                        self.load_image(result_dict['filename']))
                if result_dict.get('gt_semantic_seg', None) is None:
                    result_dict.update(
                        self.load_seg_map(
                            result_dict['seg_filename'],
                            reduce_zero_label=self.reduce_zero_label))
                if self.cache_on_the_fly:
                    self.samples_list[idx] = result_dict
            result_dict = self.post_process_fn(copy.deepcopy(result_dict))
            self._retry_count = 0
        except Exception as e:
            logging.warning(e)
            load_success = False

        if not load_success:
            logging.warning(
                'Something is wrong with current sample %s, try to load the next sample...'
                % result_dict.get('filename', ''))
            self._retry_count += 1
            if self._retry_count >= self._max_retry_num:
                raise ValueError('All samples failed to load!')

            result_dict = self[(idx + 1) % self.num_samples]

        return result_dict


@DATASOURCES.register_module
class SegSourceCocoStuff164k(SegSourceBase):
    CLASSES = [
        'unlabeled', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
        'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
        'street sign', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
        'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
        'hat', 'backpack', 'umbrella', 'shoe', 'eye glasses', 'handbag', 'tie',
        'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
        'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
        'tennis racket', 'bottle', 'plate', 'wine glass', 'cup', 'fork',
        'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
        'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window',
        'desk', 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote',
        'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
        'refrigerator', 'blender', 'book', 'clock', 'vase', 'scissors',
        'teddy bear', 'hair drier', 'toothbrush', 'hair brush', 'banner',
        'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',
        'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',
        'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',
        'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',
        'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower',
        'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel',
        'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal',
        'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', 'paper',
        'pavement', 'pillow', 'plant-other', 'plastic', 'platform',
        'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof',
        'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper',
        'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other',
        'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable',
        'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel',
        'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops',
        'window-blind', 'window-other', 'wood'
    ]
    """Data source for COCO-Stuff 164k semantic segmentation.

    data format is as follows:

    ├── data
    │   ├── images
    │   │   ├── 1.jpg
    │   │   ├── 2.jpg
    │   │   ├── ...
    │   ├── labels
    │   │   ├── 1.png
    │   │   ├── 2.png
    │   │   ├── ...

    Example1:
        data_source = SegSourceCocoStuff164k(
            label_root='/your/data/labels',
            img_root='/your/data/images',
            classes=${CLASSES}
        )

    Args:
        img_root (str): images dir path
        label_root (str): labels dir path
        classes (str | list): classes list or file
        img_suffix (str): image file suffix
        label_suffix (str): label file suffix
        reduce_zero_label (bool): whether to mark label zero as ignored
        palette (Sequence[Sequence[int]] | np.ndarray | None):
            palette of segmentation map, if none, random palette will be generated
        cache_at_init (bool): if set True, will cache in memory in __init__ for faster training
        cache_on_the_fly (bool): if set True, will cache in memory during training
    """
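
    # Unlike the 10k variant above, this class delegates sample parsing to the
    # generic `parse_raw` helper through `SegSourceBase.__init__`; labels are
    # plain .png index maps paired with images by filename (see
    # `get_source_iterator` below).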
    def __init__(self,
                 img_root,
                 label_root,
                 classes=CLASSES,
                 img_suffix='.jpg',
                 label_suffix='.png',
                 reduce_zero_label=False,
                 palette=None,
                 num_processes=int(cpu_count() / 2),
                 cache_at_init=False,
                 cache_on_the_fly=False,
                 **kwargs) -> None:
        self.img_root = img_root
        self.label_root = label_root

        self.classes = classes
        self.PALETTE = palette
        self.img_suffix = img_suffix
        self.label_suffix = label_suffix

        assert (os.path.exists(self.img_root) and os.path.exists(self.label_root)), \
            f'{self.label_root} or {self.img_root} does not exist'

        super(SegSourceCocoStuff164k, self).__init__(
            classes=classes,
            reduce_zero_label=reduce_zero_label,
            palette=palette,
            parse_fn=parse_raw,
            num_processes=num_processes,
            cache_at_init=cache_at_init,
            cache_on_the_fly=cache_on_the_fly)

    def get_source_iterator(self):
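        # Enumerate label files under label_root and pair each with the image
        # obtained by swapping the label suffix for the image suffix; pairs
        # with a missing file on either side are skipped.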
        label_files = []
        img_files = []

        label_list = os.listdir(self.label_root)
        for tmp_img in label_list:
            label_file = os.path.join(self.label_root, tmp_img)
            img_file = os.path.join(
                self.img_root,
                tmp_img.replace(self.label_suffix, self.img_suffix))

            if os.path.exists(label_file) and os.path.exists(img_file):
                label_files.append(label_file)
                img_files.append(img_file)

        return list(zip(img_files, label_files))