EasyCV/easycv/datasets/segmentation/data_sources/raw.py

131 lines
4.8 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
from multiprocessing import cpu_count
from easycv.datasets.registry import DATASOURCES
from easycv.file import io
from .base import SegSourceBase
def parse_raw(source_item, classes):
    """Convert one (image path, label path) pair into a sample dict.

    Args:
        source_item (Sequence[str]): exactly two items — the image path
            followed by the segmentation label path.
        classes: not used here; presumably accepted only to match the
            ``parse_fn`` calling convention of ``SegSourceBase``.

    Returns:
        dict: ``{'filename': <image path>, 'seg_filename': <label path>}``
    """
    image_file, label_file = source_item
    return dict(filename=image_file, seg_filename=label_file)
@DATASOURCES.register_module
class SegSourceRaw(SegSourceBase):
    """Data source for semantic segmentation.

    data format is as follows:
    ├── data_dir
    │   ├── images
    │   │   ├── 1.jpg
    │   │   ├── 2.jpg
    │   │   ├── ...
    │   ├── labels
    │   │   ├── 1.png
    │   │   ├── 2.png
    │   │   ├── ...

    Each image is paired with the label file in ``label_root`` that has the
    same base name (without extension) plus one of ``label_suffix``. Images
    without a matching label are skipped with a warning.

    Args:
        img_root (str): images dir path
        label_root (str): labels dir path
        split (str, optional): Split txt file. If split is specified, each
            non-empty line is an image name relative to img_root without
            suffix, and only those images are loaded. Otherwise, all images
            under img_root (searched recursively) are loaded.
        classes (str | list): classes list or file
        img_suffix (str | list[str]): image file suffix(es)
        label_suffix (str | list[str]): label file suffix(es)
        reduce_zero_label (bool): whether to mark label zero as ignored
        palette (Sequence[Sequence[int]] | np.ndarray | None):
            palette of segmentation map, if none, random palette will be generated
        num_processes (int): forwarded to SegSourceBase (presumably the number
            of worker processes used there). Defaults to half the CPU count,
            but at least 1.
        cache_at_init (bool): if set True, will cache in memory in __init__ for faster training
        cache_on_the_fly (bool): if set True, will cache in memory during training
    """

    def __init__(self,
                 img_root=None,
                 label_root=None,
                 split=None,
                 classes=None,
                 img_suffix='.jpg',
                 label_suffix='.png',
                 reduce_zero_label=False,
                 palette=None,
                 num_processes=max(1, cpu_count() // 2),
                 cache_at_init=False,
                 cache_on_the_fly=False):
        # Note: the num_processes default is evaluated once at import time;
        # max(1, ...) prevents a 0 default on single-CPU machines.
        self.img_root = img_root
        self.label_root = label_root
        self.split = split
        self.classes = classes
        self.PALETTE = palette
        # Normalize suffixes to lists so several extensions can be probed.
        self.img_suffix = [img_suffix] if isinstance(img_suffix,
                                                     str) else img_suffix
        self.label_suffix = [label_suffix] if isinstance(
            label_suffix, str) else label_suffix
        assert isinstance(self.img_suffix, list)
        assert isinstance(self.label_suffix, list)

        super(SegSourceRaw, self).__init__(
            classes=classes,
            reduce_zero_label=reduce_zero_label,
            palette=palette,
            parse_fn=parse_raw,
            num_processes=num_processes,
            cache_at_init=cache_at_init,
            cache_on_the_fly=cache_on_the_fly)

    def _load_split_images(self):
        """Resolve image paths for the names listed in ``self.split``.

        Every suffix in ``img_suffix`` that yields an existing file adds one
        path, so a name may contribute several images.

        Returns:
            list[str]: image paths under ``img_root``.
        """
        with io.open(self.split, 'r') as f:
            lines = f.readlines()

        img_files = []
        for line in lines:
            name = line.strip()
            if not name:
                # Blank lines (e.g. a trailing newline) name no sample.
                continue
            found = False
            for img_suf in self.img_suffix:
                filename = os.path.join(self.img_root, name + img_suf)
                # io.exists for consistency with the io.open / io.listdir /
                # io.exists calls used elsewhere in this class.
                if io.exists(filename):
                    img_files.append(filename)
                    found = True
            if not found:
                logging.warning('Not find file: %s with suffix %s' %
                                (name, self.img_suffix))
        return img_files

    def _find_label_file(self, img_path):
        """Return the first existing label path for ``img_path``, else None.

        Candidates are ``label_root/<image base name><suffix>`` for each
        suffix in ``label_suffix``, tried in order.
        """
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        candidates = [
            os.path.join(self.label_root, img_name + label_format)
            for label_format in self.label_suffix
        ]
        for label_path in candidates:
            if io.exists(label_path):
                return label_path
        logging.warning(
            'Not find label file %s for img: %s, skip the sample!' %
            (', '.join(candidates), img_path))
        return None

    def get_source_iterator(self):
        """Collect (image path, label path) pairs for this data source.

        Also stores the matched paths in ``self.img_files`` and
        ``self.label_files`` (same length, aligned by index).

        Returns:
            list[tuple[str, str]]: (image path, label path) pairs.

        Raises:
            AssertionError: if no image/label pair is found.
        """
        if self.split is not None:
            candidate_imgs = self._load_split_images()
        else:
            # Compare case-insensitively on both sides so that e.g. '.JPG'
            # in img_suffix also matches.
            wanted_exts = {suf.lower() for suf in self.img_suffix}
            candidate_imgs = [
                os.path.join(self.img_root, i)
                for i in io.listdir(self.img_root, recursive=True)
                if os.path.splitext(i)[-1].lower() in wanted_exts
            ]

        # Build fresh, aligned lists instead of removing from the list being
        # iterated (which would skip the element after each removal).
        self.img_files = []
        self.label_files = []
        for img_path in candidate_imgs:
            label_path = self._find_label_file(img_path)
            if label_path is None:
                continue
            self.img_files.append(img_path)
            self.label_files.append(label_path)

        assert len(self.img_files) == len(self.label_files)
        assert len(
            self.img_files) > 0, 'No samples found in %s' % self.img_root
        return list(zip(self.img_files, self.label_files))