EasyCV/easycv/datasets/segmentation/data_sources/base.py

196 lines
7.0 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import functools
import logging
from abc import abstractmethod
from multiprocessing import Pool, cpu_count
import cv2
import mmcv
import numpy as np
from tqdm import tqdm
from easycv.datasets.registry import DATASOURCES
from easycv.file.image import load_image as _load_img
from easycv.framework.errors import NotImplementedError, ValueError
def load_image(img_path):
    """Load ``img_path`` in BGR mode and pack it into a result dict.

    Args:
        img_path: path of the image file, passed to ``_load_img``.

    Returns:
        dict with keys ``img`` (the image converted to float32),
        ``img_shape`` and ``ori_shape`` (both the shape of the loaded
        array, documented as h, w, c).
    """
    bgr = _load_img(img_path, mode='BGR')
    shape = bgr.shape  # h, w, c
    return dict(
        img=bgr.astype(np.float32),
        img_shape=shape,
        ori_shape=shape,
    )
def load_seg_map(seg_path, reduce_zero_label):
    """Load a palette-mode ('P') segmentation map.

    Args:
        seg_path: path of the segmentation map, passed to ``_load_img``.
        reduce_zero_label (bool): if True, label 0 becomes the ignore
            value 255 and every other label is shifted down by one.

    Returns:
        dict with the single key ``gt_semantic_seg``.
    """
    seg_map = _load_img(seg_path, mode='P')
    if not reduce_zero_label:
        return {'gt_semantic_seg': seg_map}

    # Map 0 to 255 *before* subtracting so it never wraps below zero
    # (avoids underflow if the array is unsigned); after the shift
    # those entries read 254 and are pushed back up to the ignore value.
    seg_map[seg_map == 0] = 255
    seg_map = seg_map - 1
    seg_map[seg_map == 254] = 255
    return {'gt_semantic_seg': seg_map}
def build_sample(source_item, classes, parse_fn, load_img, reduce_zero_label):
    """Build sample info from source item.

    Args:
        source_item: item of source iterator
        classes: classes list
        parse_fn: parse function to parse source_item, only accepts two
            params: source_item and classes. It should return a dict with
            at least ``filename`` and ``seg_filename`` keys, or a falsy
            value (e.g. None) to mark the item as unusable.
        load_img: load image or not, if true, cache all images in memory
            at init
        reduce_zero_label: forwarded to ``load_seg_map`` when ``load_img``
            is true

    Returns:
        The dict returned by ``parse_fn``, updated with the loaded image and
        segmentation map when ``load_img`` is true. A falsy result from
        ``parse_fn`` is returned unchanged.
    """
    result_dict = parse_fn(source_item, classes)
    # ``SegSourceBase.build_samples`` drops falsy results; return them
    # as-is instead of crashing on ``result_dict['filename']`` below.
    if not result_dict:
        return result_dict
    if load_img:
        result_dict.update(load_image(result_dict['filename']))
        result_dict.update(
            load_seg_map(result_dict['seg_filename'], reduce_zero_label))
    return result_dict
@DATASOURCES.register_module
class SegSourceBase(object):
    """Data source for semantic segmentation.

    Subclasses implement ``get_source_iterator``; every item it yields is
    passed to ``parse_fn`` together with the class list to build one sample.

    Args:
        classes (str | list | tuple): classes list, or a file path read with
            ``mmcv.list_from_file``. If None, the class attribute
            ``CLASSES`` is used.
        reduce_zero_label (bool): whether to mark label zero as ignored
        palette (Sequence[Sequence[int]] | np.ndarray | None):
            palette of segmentation map, if None, a random palette will be
            generated
        parse_fn (callable): called as ``parse_fn(source_item, classes)``;
            should return a dict with ``filename`` and ``seg_filename`` keys,
            or a falsy value to drop the item.
        num_processes (int): number of processes to parse samples
        cache_at_init (bool): if set True, will cache in memory in __init__
            for faster training
        cache_on_the_fly (bool): if set True, will cache in memory during
            training
    """
    CLASSES = None
    PALETTE = None

    def __init__(self,
                 classes=None,
                 reduce_zero_label=False,
                 palette=None,
                 parse_fn=None,
                 num_processes=max(1, cpu_count() // 2),
                 cache_at_init=False,
                 cache_on_the_fly=False):
        # ``max(1, ...)``: on a single-CPU machine ``cpu_count() // 2`` is 0,
        # which ``multiprocessing.Pool`` rejects.
        if classes is not None:
            self.CLASSES = classes
        if palette is not None:
            self.PALETTE = palette
        self.reduce_zero_label = reduce_zero_label
        self.cache_at_init = cache_at_init
        self.cache_on_the_fly = cache_on_the_fly
        self.num_processes = num_processes

        if self.cache_at_init and self.cache_on_the_fly:
            raise ValueError(
                'Only one of `cache_on_the_fly` and `cache_at_init` can be True!'
            )

        assert isinstance(self.CLASSES, (str, tuple, list)), (
            'classes must be a list/tuple of names or a file path, got %r' %
            type(self.CLASSES))
        if isinstance(self.CLASSES, str):
            # Read ``self.CLASSES`` rather than the ``classes`` argument, so a
            # file path declared as a subclass attribute also works.
            self.CLASSES = mmcv.list_from_file(self.CLASSES)

        if self.PALETTE is None:
            self.PALETTE = self.get_random_palette()

        source_iter = self.get_source_iterator()
        process_fn = functools.partial(
            build_sample,
            parse_fn=parse_fn,
            classes=self.CLASSES,
            # Same truthiness test as ``__getitem__`` uses, so "cached at
            # init" means the same thing in both places.
            load_img=bool(self.cache_at_init),
            reduce_zero_label=self.reduce_zero_label)
        self.samples_list = self.build_samples(
            source_iter, process_fn=process_fn)
        self.num_samples = len(self.samples_list)
        # An error will be raised if failed to load _max_retry_num times in a row
        self._max_retry_num = self.num_samples
        self._retry_count = 0

    # NOTE: this class does not use ``abc.ABCMeta``, so ``@abstractmethod``
    # is informational only; subclasses are expected to override this.
    @abstractmethod
    def get_source_iterator(self):
        """Return data list iterator, source iterator will be passed to parse_fn,
        and parse_fn will receive params of item of source iter and classes for parsing.
        What does parse_fn need, what does source iterator returns.
        """
        raise NotImplementedError

    def build_samples(self, iterable, process_fn):
        """Apply ``process_fn`` to every item of ``iterable`` in a process pool.

        Results are kept in the order of ``iterable``; falsy results are
        dropped.

        Args:
            iterable: source items; may be a sized container or a plain
                iterator/generator (progress bar then has no total).
            process_fn: picklable callable applied to each item.

        Returns:
            list of the truthy results.
        """
        total = len(iterable) if hasattr(iterable, '__len__') else None
        samples_list = []
        with Pool(processes=self.num_processes) as p:
            with tqdm(total=total, desc='Scanning images') as pbar:
                # ``imap`` rather than ``imap_unordered``: the index -> sample
                # mapping must not depend on worker scheduling, so separate
                # constructions of the same source agree.
                for result_dict in p.imap(process_fn, iterable):
                    if result_dict:
                        samples_list.append(result_dict)
                    pbar.update()
        return samples_list

    def __getitem__(self, idx):
        """Return the post-processed sample at ``idx``.

        If loading or post-processing fails, a warning is logged and the
        next sample (wrapping around) is tried instead. ``ValueError`` is
        raised after ``_max_retry_num`` consecutive failures. An
        out-of-range ``idx`` raises ``IndexError`` as a normal sequence would.
        """
        # Retry with a loop rather than recursion so a long run of broken
        # samples cannot exceed Python's recursion limit.
        while True:
            result_dict = self.samples_list[idx]
            try:
                # avoid data cache from taking up too much memory
                if not self.cache_at_init and not self.cache_on_the_fly:
                    result_dict = copy.deepcopy(result_dict)
                if not self.cache_at_init:
                    if result_dict.get('img', None) is None:
                        result_dict.update(
                            load_image(result_dict['filename']))
                    if result_dict.get('gt_semantic_seg', None) is None:
                        result_dict.update(
                            load_seg_map(
                                result_dict['seg_filename'],
                                reduce_zero_label=self.reduce_zero_label))
                    if self.cache_on_the_fly:
                        self.samples_list[idx] = result_dict
                result_dict = self.post_process_fn(copy.deepcopy(result_dict))
                self._retry_count = 0
                return result_dict
            except Exception as e:
                logging.warning(e)

            logging.warning(
                'Something wrong with current sample %s,Try load next sample...',
                result_dict.get('filename', ''))
            self._retry_count += 1
            if self._retry_count >= self._max_retry_num:
                raise ValueError('All samples failed to load!')
            idx = (idx + 1) % self.num_samples

    def post_process_fn(self, result_dict):
        """Fill in default ``img_fields``/``seg_fields`` when missing or None."""
        if result_dict.get('img_fields', None) is None:
            result_dict['img_fields'] = ['img']
        if result_dict.get('seg_fields', None) is None:
            result_dict['seg_fields'] = ['gt_semantic_seg']
        return result_dict

    def get_random_palette(self):
        """Return a deterministic random palette of shape (num_classes, 3).

        Values are drawn from [0, 255). A private ``RandomState(42)``
        produces the same values as seeding the global generator with 42.
        It leaves the global NumPy random state untouched, so randomness
        elsewhere is not affected.
        See: https://github.com/open-mmlab/mmdetection/issues/5844
        """
        rng = np.random.RandomState(42)
        return rng.randint(0, 255, size=(len(self.CLASSES), 3))

    def __len__(self):
        return len(self.samples_list)