mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
130 lines
5.0 KiB
Python
130 lines
5.0 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import copy
|
|
import logging
|
|
import os
|
|
import subprocess
|
|
|
|
import numpy as np
|
|
|
|
from easycv.datasets.registry import DATASOURCES
|
|
from easycv.file import io
|
|
from easycv.file.image import load_image as _load_img
|
|
from .raw import SegSourceRaw
|
|
|
|
try:
|
|
import cityscapesscripts.helpers.labels as CSLabels
|
|
except ModuleNotFoundError as e:
|
|
res = subprocess.call('pip install cityscapesscripts', shell=True)
|
|
if res != 0:
|
|
info_string = (
|
|
'\n\nAuto install failed! Please install cityscapesscripts with the following commands :\n'
|
|
'\t`pip install cityscapesscripts`\n')
|
|
raise ModuleNotFoundError(info_string)
|
|
|
|
|
|
def load_seg_map_cityscape(seg_path, reduce_zero_label):
|
|
gt_semantic_seg = _load_img(seg_path, mode='P')
|
|
gt_semantic_seg_copy = gt_semantic_seg.copy()
|
|
for labels in CSLabels.labels:
|
|
gt_semantic_seg_copy[gt_semantic_seg == labels.id] = labels.trainId
|
|
|
|
return {'gt_semantic_seg': gt_semantic_seg_copy}
|
|
|
|
|
|
@DATASOURCES.register_module
|
|
class SegSourceCityscapes(SegSourceRaw):
|
|
"""Cityscapes datasource
|
|
"""
|
|
CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
|
|
'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
|
|
'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
|
|
'bicycle')
|
|
|
|
PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
|
|
[190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
|
|
[107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
|
|
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
|
|
[0, 80, 100], [0, 0, 230], [119, 11, 32]]
|
|
|
|
def __init__(self,
|
|
img_suffix='_leftImg8bit.png',
|
|
label_suffix='_gtFine_labelIds.png',
|
|
**kwargs):
|
|
super(SegSourceCityscapes, self).__init__(
|
|
img_suffix=img_suffix, label_suffix=label_suffix, **kwargs)
|
|
|
|
def __getitem__(self, idx):
|
|
result_dict = self.samples_list[idx]
|
|
load_success = True
|
|
try:
|
|
# avoid data cache from taking up too much memory
|
|
if not self.cache_at_init and not self.cache_on_the_fly:
|
|
result_dict = copy.deepcopy(result_dict)
|
|
|
|
if not self.cache_at_init:
|
|
if result_dict.get('img', None) is None:
|
|
img = _load_img(result_dict['filename'], mode='BGR')
|
|
result = {
|
|
'img': img.astype(np.float32),
|
|
'img_shape': img.shape, # h, w, c
|
|
'ori_shape': img.shape,
|
|
}
|
|
result_dict.update(result)
|
|
if result_dict.get('gt_semantic_seg', None) is None:
|
|
result_dict.update(
|
|
load_seg_map_cityscape(
|
|
result_dict['seg_filename'],
|
|
reduce_zero_label=self.reduce_zero_label))
|
|
if self.cache_on_the_fly:
|
|
self.samples_list[idx] = result_dict
|
|
result_dict = self.post_process_fn(copy.deepcopy(result_dict))
|
|
self._retry_count = 0
|
|
except Exception as e:
|
|
logging.warning(e)
|
|
load_success = False
|
|
|
|
if not load_success:
|
|
logging.warning(
|
|
'Something wrong with current sample %s,Try load next sample...'
|
|
% result_dict.get('filename', ''))
|
|
self._retry_count += 1
|
|
if self._retry_count >= self._max_retry_num:
|
|
raise ValueError('All samples failed to load!')
|
|
|
|
result_dict = self[(idx + 1) % self.num_samples]
|
|
|
|
return result_dict
|
|
|
|
def get_source_iterator(self):
|
|
|
|
self.img_files = [
|
|
os.path.join(self.img_root, i)
|
|
for i in io.listdir(self.img_root, recursive=True)
|
|
if i.endswith(self.img_suffix[0])
|
|
]
|
|
|
|
self.label_files = []
|
|
for img_path in self.img_files:
|
|
self.img_root = os.path.join(self.img_root, '')
|
|
img_name = img_path.replace(self.img_root,
|
|
'')[:-len(self.img_suffix[0])]
|
|
find_label_path = False
|
|
for label_format in self.label_suffix:
|
|
lable_path = os.path.join(self.label_root,
|
|
img_name + label_format)
|
|
if io.exists(lable_path):
|
|
find_label_path = True
|
|
self.label_files.append(lable_path)
|
|
break
|
|
if not find_label_path:
|
|
logging.warning(
|
|
'Not find label file %s for img: %s, skip the sample!' %
|
|
(lable_path, img_path))
|
|
self.img_files.remove(img_path)
|
|
|
|
assert len(self.img_files) == len(self.label_files)
|
|
assert len(
|
|
self.img_files) > 0, 'No samples found in %s' % self.img_root
|
|
|
|
return list(zip(self.img_files, self.label_files))
|