mirror of https://github.com/alibaba/EasyCV.git
196 lines
7.0 KiB
Python
196 lines
7.0 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import copy
|
|
import functools
|
|
import logging
|
|
from abc import abstractmethod
|
|
from multiprocessing import Pool, cpu_count
|
|
|
|
import cv2
|
|
import mmcv
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
|
|
from easycv.datasets.registry import DATASOURCES
|
|
from easycv.file.image import load_image as _load_img
|
|
from easycv.framework.errors import NotImplementedError, ValueError
|
|
|
|
|
|
def load_image(img_path):
    """Read the image at ``img_path`` in BGR channel order.

    Args:
        img_path (str): path of the image to read.

    Returns:
        dict: ``img`` holds the pixels cast to ``np.float32``; ``img_shape``
        and ``ori_shape`` both hold the shape of the loaded array (h, w, c).
    """
    bgr = _load_img(img_path, mode='BGR')
    shape = bgr.shape  # h, w, c
    return dict(
        img=bgr.astype(np.float32),
        img_shape=shape,
        ori_shape=shape,
    )
|
|
|
|
|
|
def load_seg_map(seg_path, reduce_zero_label):
    """Read the segmentation map at ``seg_path`` (palette mode 'P').

    Args:
        seg_path (str): path of the label map.
        reduce_zero_label (bool): if true, label 0 is turned into the ignore
            value 255 and every other label 1..254 is shifted down by one
            (255 stays 255).

    Returns:
        dict: ``{'gt_semantic_seg': label_map}``.
    """
    label_map = _load_img(seg_path, mode='P')
    if not reduce_zero_label:
        return {'gt_semantic_seg': label_map}

    # Send 0 to 255 before subtracting so the shift below never wraps
    # (presumably the map is unsigned, e.g. uint8 -- confirm with loader).
    label_map[label_map == 0] = 255
    label_map = label_map - 1
    # Former 0s and former 255s both land on 254; restore the ignore value.
    label_map[label_map == 254] = 255
    return {'gt_semantic_seg': label_map}
|
|
|
|
|
|
def build_sample(source_item, classes, parse_fn, load_img, reduce_zero_label):
    """Build sample info from source item.

    Args:
        source_item: item of source iterator
        classes: classes list
        parse_fn: parse function to parse source_item, only accepts two
            params: source_item and classes. It may return a falsy value to
            reject the item.
        load_img: load image or not, if true, cache all images in memory at init
        reduce_zero_label (bool): forwarded to ``load_seg_map`` when
            ``load_img`` is true.

    Returns:
        The dict returned by ``parse_fn``. When ``load_img`` is true, it is
        extended with the image fields from ``load_image`` and the
        ``gt_semantic_seg`` entry from ``load_seg_map``. A falsy result from
        ``parse_fn`` is returned unchanged.
    """
    result_dict = parse_fn(source_item, classes)

    # A rejected item (falsy result, which build_samples drops) has no
    # 'filename'/'seg_filename' to load; skip loading instead of crashing.
    if load_img and result_dict:
        result_dict.update(load_image(result_dict['filename']))
        result_dict.update(
            load_seg_map(result_dict['seg_filename'], reduce_zero_label))

    return result_dict
|
|
|
|
|
|
@DATASOURCES.register_module
class SegSourceBase(object):
    """Data source for semantic segmentation.

    Subclasses implement :meth:`get_source_iterator`. Every item it yields is
    passed to ``parse_fn(item, classes)``, which must return a dict holding at
    least ``filename`` and ``seg_filename``, or a falsy value to drop the item.

    Args:
        classes (str | list | tuple | None): classes list, or path of a file
            listing one class per line. Falls back to ``CLASSES`` when None.
        reduce_zero_label (bool): whether to mark label zero as ignored
        palette (Sequence[Sequence[int]] | np.ndarray | None): palette of
            segmentation map, if none, a random (fixed-seed) palette will be
            generated
        parse_fn (callable): ``(source_item, classes) -> dict`` parser.
        num_processes (int | None): number of processes to parse samples.
            Values below 1 are clamped to 1; None lets ``multiprocessing.Pool``
            pick the worker count.
        cache_at_init (bool): if set True, will cache in memory in __init__
            for faster training
        cache_on_the_fly (bool): if set True, will cache in memory during
            training
    """
    CLASSES = None

    PALETTE = None

    def __init__(self,
                 classes=None,
                 reduce_zero_label=False,
                 palette=None,
                 parse_fn=None,
                 num_processes=int(cpu_count() / 2),
                 cache_at_init=False,
                 cache_on_the_fly=False):

        if classes is not None:
            self.CLASSES = classes
        if palette is not None:
            self.PALETTE = palette

        self.reduce_zero_label = reduce_zero_label
        self.cache_at_init = cache_at_init
        self.cache_on_the_fly = cache_on_the_fly
        self.num_processes = num_processes

        if self.cache_at_init and self.cache_on_the_fly:
            raise ValueError(
                'Only one of `cache_on_the_fly` and `cache_at_init` can be True!'
            )

        assert isinstance(self.CLASSES, (str, tuple, list)), \
            'classes must be a list/tuple of names or a path to a classes file'
        if isinstance(self.CLASSES, str):
            # Read self.CLASSES rather than the `classes` argument: the path
            # may come from a subclass-level CLASSES when classes is None.
            self.CLASSES = mmcv.list_from_file(self.CLASSES)
        if self.PALETTE is None:
            self.PALETTE = self.get_random_palette()

        source_iter = self.get_source_iterator()

        process_fn = functools.partial(
            build_sample,
            parse_fn=parse_fn,
            classes=self.CLASSES,
            load_img=bool(cache_at_init),
            reduce_zero_label=self.reduce_zero_label)
        self.samples_list = self.build_samples(
            source_iter, process_fn=process_fn)
        self.num_samples = len(self.samples_list)
        # An error will be raised if failed to load _max_retry_num times in a row
        self._max_retry_num = self.num_samples
        self._retry_count = 0

    @abstractmethod
    def get_source_iterator(self):
        """Return data list iterator, source iterator will be passed to parse_fn,
        and parse_fn will receive params of item of source iter and classes for parsing.
        What does parse_fn need, what does source iterator returns.
        """
        raise NotImplementedError

    def build_samples(self, iterable, process_fn):
        """Apply ``process_fn`` to every item of ``iterable`` in a process pool.

        Falsy results (items rejected by ``parse_fn``) are dropped. Results
        are gathered with ``imap_unordered``, so the returned list is not
        guaranteed to follow the order of ``iterable``.
        """
        num_processes = self.num_processes
        if num_processes is not None:
            # Pool rejects processes < 1, which the default
            # int(cpu_count() / 2) produces on single-core machines.
            num_processes = max(1, int(num_processes))

        try:
            total = len(iterable)
        except TypeError:
            # Generators / plain iterators have no length: open-ended bar.
            total = None

        samples_list = []
        with Pool(processes=num_processes) as pool:
            with tqdm(total=total, desc='Scanning images') as pbar:
                for result_dict in pool.imap_unordered(process_fn, iterable):
                    if result_dict:
                        samples_list.append(result_dict)
                    pbar.update()

        return samples_list

    def __getitem__(self, idx):
        """Return the post-processed sample at ``idx``.

        When a sample fails to load, the following samples are tried in turn
        (wrapping around). ValueError is raised once ``_max_retry_num`` loads
        in a row have failed.
        """
        while True:
            result_dict = self.samples_list[idx]
            try:
                loaded = self._load_sample(idx, result_dict)
            except Exception as e:
                logging.warning(e)
            else:
                self._retry_count = 0
                return loaded

            logging.warning(
                'Something wrong with current sample %s,Try load next sample...',
                result_dict.get('filename', ''))
            self._retry_count += 1
            if self._retry_count >= self._max_retry_num:
                raise ValueError('All samples failed to load!')

            # Loop instead of recursing so long runs of broken samples cannot
            # exhaust the interpreter's recursion limit.
            idx = (idx + 1) % self.num_samples

    def _load_sample(self, idx, result_dict):
        """Load image / seg map for one sample (honouring caching) and post-process it."""
        # Without any caching, load into a copy so the loaded arrays are not
        # kept in samples_list and memory does not grow over training.
        if not self.cache_at_init and not self.cache_on_the_fly:
            result_dict = copy.deepcopy(result_dict)

        if not self.cache_at_init:
            if result_dict.get('img', None) is None:
                result_dict.update(load_image(result_dict['filename']))
            if result_dict.get('gt_semantic_seg', None) is None:
                result_dict.update(
                    load_seg_map(
                        result_dict['seg_filename'],
                        reduce_zero_label=self.reduce_zero_label))
            if self.cache_on_the_fly:
                self.samples_list[idx] = result_dict

        # Post-process a copy so the cached entry is never mutated downstream.
        return self.post_process_fn(copy.deepcopy(result_dict))

    def post_process_fn(self, result_dict):
        """Fill in default ``img_fields`` / ``seg_fields`` when missing or None."""
        if result_dict.get('img_fields', None) is None:
            result_dict['img_fields'] = ['img']
        if result_dict.get('seg_fields', None) is None:
            result_dict['seg_fields'] = ['gt_semantic_seg']

        return result_dict

    def get_random_palette(self):
        """Return a reproducible random palette of shape (num_classes, 3).

        A private RandomState seeded with 42 is used instead of reseeding the
        global RNG, so global randomness is left untouched (the palette would
        otherwise differ between iterations if not specified). It draws the
        same values as seeding the global NumPy RNG with 42.
        See: https://github.com/open-mmlab/mmdetection/issues/5844
        """
        rng = np.random.RandomState(42)
        return rng.randint(0, 255, size=(len(self.CLASSES), 3))

    def __len__(self):
        return len(self.samples_list)
|