yhq 26cd12ab42
[features] Support stdc (#284)
* add stdc semantic segmentation algorithm
2023-02-16 14:00:59 +08:00

130 lines
5.0 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import logging
import os
import subprocess
import numpy as np
from easycv.datasets.registry import DATASOURCES
from easycv.file import io
from easycv.file.image import load_image as _load_img
from .raw import SegSourceRaw
# cityscapesscripts provides the Cityscapes label table (``id`` -> ``trainId``)
# used by ``load_seg_map_cityscape``. If the package is missing, attempt a
# one-off automatic install and fail with install instructions otherwise.
try:
    import cityscapesscripts.helpers.labels as CSLabels
except ModuleNotFoundError as e:
    import importlib
    import sys

    # Install with the running interpreter's pip (rather than whatever `pip`
    # is first on PATH) so the package lands where this process imports from.
    res = subprocess.call(
        [sys.executable, '-m', 'pip', 'install', 'cityscapesscripts'])
    if res != 0:
        info_string = (
            '\n\nAuto install failed! Please install cityscapesscripts with the following commands :\n'
            '\t`pip install cityscapesscripts`\n')
        raise ModuleNotFoundError(info_string) from e
    # The first import failed, so CSLabels is still unbound: re-import now that
    # the package is installed. Finder caches must be invalidated for a package
    # created after interpreter start-up to become importable.
    importlib.invalidate_caches()
    import cityscapesscripts.helpers.labels as CSLabels
def load_seg_map_cityscape(seg_path, reduce_zero_label):
    """Load a segmentation map and remap raw Cityscapes label ids to train ids.

    Args:
        seg_path (str): Path of the segmentation map image, loaded in 'P' mode.
        reduce_zero_label (bool): Accepted for signature compatibility with the
            other segmentation-map loaders; it is not used here.

    Returns:
        dict: ``{'gt_semantic_seg': train_ids}``. Every pixel equal to some
        ``label.id`` in ``cityscapesscripts.helpers.labels.labels`` is replaced
        by that label's ``trainId``; pixels matching no id are left unchanged.
    """
    raw_ids = _load_img(seg_path, mode='P')
    # Write into a copy and always match against the untouched original, so a
    # value produced by one remapping can never be re-matched by a later label.
    train_ids = raw_ids.copy()
    for cs_label in CSLabels.labels:
        hit = raw_ids == cs_label.id
        train_ids[hit] = cs_label.trainId
    return dict(gt_semantic_seg=train_ids)
@DATASOURCES.register_module
class SegSourceCityscapes(SegSourceRaw):
    """Cityscapes datasource.

    By default pairs ``*_leftImg8bit.png`` images found under ``img_root`` with
    ``*_gtFine_labelIds.png`` annotations under ``label_root``. When a sample is
    loaded, its raw Cityscapes label ids are remapped to train ids via
    ``load_seg_map_cityscape``. Caching, post-processing and retry state
    (``cache_at_init``, ``cache_on_the_fly``, ``post_process_fn``,
    ``samples_list``, ``_retry_count``, ``_max_retry_num``, ...) are provided by
    ``SegSourceRaw``.
    """
    CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
               'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
               'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
               'bicycle')

    PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
               [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
               [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
               [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
               [0, 80, 100], [0, 0, 230], [119, 11, 32]]

    def __init__(self,
                 img_suffix='_leftImg8bit.png',
                 label_suffix='_gtFine_labelIds.png',
                 **kwargs):
        """
        Args:
            img_suffix: Filename suffix identifying image files.
            label_suffix: Filename suffix of the matching label files.
            **kwargs: Forwarded unchanged to ``SegSourceRaw``.
        """
        super(SegSourceCityscapes, self).__init__(
            img_suffix=img_suffix, label_suffix=label_suffix, **kwargs)

    def __getitem__(self, idx):
        """Return the processed sample at ``idx``.

        The image and label map are loaded lazily unless the sample was cached
        at init, then ``post_process_fn`` is applied to a deep copy. If anything
        raises, the error is logged and the next index (wrapping around) is
        tried instead. Once ``_retry_count`` reaches ``_max_retry_num`` a
        ``ValueError`` is raised; a successful load resets the counter.
        """
        result_dict = self.samples_list[idx]
        load_success = True
        try:
            # avoid data cache from taking up too much memory
            if not self.cache_at_init and not self.cache_on_the_fly:
                result_dict = copy.deepcopy(result_dict)

            if not self.cache_at_init:
                if result_dict.get('img', None) is None:
                    img = _load_img(result_dict['filename'], mode='BGR')
                    result = {
                        'img': img.astype(np.float32),
                        'img_shape': img.shape,  # h, w, c
                        'ori_shape': img.shape,
                    }
                    result_dict.update(result)
                if result_dict.get('gt_semantic_seg', None) is None:
                    result_dict.update(
                        load_seg_map_cityscape(
                            result_dict['seg_filename'],
                            reduce_zero_label=self.reduce_zero_label))
                if self.cache_on_the_fly:
                    # Keep the loaded (un-post-processed) sample for reuse.
                    self.samples_list[idx] = result_dict
            # Post-process a copy so cached entries are never mutated.
            result_dict = self.post_process_fn(copy.deepcopy(result_dict))
            self._retry_count = 0
        except Exception as e:
            logging.warning(e)
            load_success = False

        if not load_success:
            logging.warning(
                'Something wrong with current sample %s,Try load next sample...'
                % result_dict.get('filename', ''))
            self._retry_count += 1
            if self._retry_count >= self._max_retry_num:
                raise ValueError('All samples failed to load!')
            result_dict = self[(idx + 1) % self.num_samples]

        return result_dict

    def get_source_iterator(self):
        """Collect ``(image_path, label_path)`` pairs for this datasource.

        Images are the files listed (recursively) under ``img_root`` whose name
        ends with the first entry of ``img_suffix``. For each image, the label
        path is ``label_root/<relative image name without suffix><suffix>``
        using the first entry of ``label_suffix`` whose file exists. Images
        without an existing label file are logged and skipped.

        Side effects: normalizes ``self.img_root`` to end with a path separator
        and sets ``self.img_files`` / ``self.label_files`` to the kept pairs.

        Returns:
            list[tuple[str, str]]: The matched ``(image_path, label_path)`` pairs.

        Raises:
            AssertionError: If no image/label pair was found.
        """
        # NOTE(review): img_suffix / label_suffix are assumed to have been
        # normalized to sequences by SegSourceRaw (they are indexed and
        # iterated here) -- confirm against the base class.
        img_suffix = self.img_suffix[0]
        candidates = [
            os.path.join(self.img_root, i)
            for i in io.listdir(self.img_root, recursive=True)
            if i.endswith(img_suffix)
        ]
        # Normalize once (instead of on every loop iteration) so the prefix
        # stripped below is the same whether or not img_root ended with '/'.
        self.img_root = os.path.join(self.img_root, '')

        img_files, label_files = [], []
        # Build fresh lists instead of calling list.remove() on the list being
        # iterated, which silently skipped the image after each unmatched one.
        for img_path in candidates:
            rel_path = img_path.replace(self.img_root, '')
            # Slice explicitly so an empty suffix does not yield ''.
            img_name = rel_path[:len(rel_path) - len(img_suffix)]
            label_path = None
            for label_format in self.label_suffix:
                label_path = os.path.join(self.label_root,
                                          img_name + label_format)
                if io.exists(label_path):
                    break
            else:
                # No candidate label file exists (or label_suffix is empty).
                logging.warning(
                    'Not find label file %s for img: %s, skip the sample!' %
                    (label_path, img_path))
                continue
            img_files.append(img_path)
            label_files.append(label_path)

        self.img_files = img_files
        self.label_files = label_files

        assert len(self.img_files) == len(self.label_files)
        assert len(
            self.img_files) > 0, 'No samples found in %s' % self.img_root

        return list(zip(self.img_files, self.label_files))