Use img_prefix and seg_prefix for loading ()

* Use img_prefix and seg_prefix for loading

* flake8

* Fix split
pull/1801/head
David de la Iglesia Castro 2020-09-24 18:48:16 +02:00 committed by GitHub
parent 588a2c036a
commit 51e4cdefc5
1 changed files with 8 additions and 8 deletions
mmseg/datasets

View File

@ -131,19 +131,16 @@ class CustomDataset(Dataset):
with open(split) as f: with open(split) as f:
for line in f: for line in f:
img_name = line.strip() img_name = line.strip()
img_file = osp.join(img_dir, img_name + img_suffix) img_info = dict(filename=img_name + img_suffix)
img_info = dict(filename=img_file)
if ann_dir is not None: if ann_dir is not None:
seg_map = osp.join(ann_dir, img_name + seg_map_suffix) seg_map = img_name + seg_map_suffix
img_info['ann'] = dict(seg_map=seg_map) img_info['ann'] = dict(seg_map=seg_map)
img_infos.append(img_info) img_infos.append(img_info)
else: else:
for img in mmcv.scandir(img_dir, img_suffix, recursive=True): for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
img_file = osp.join(img_dir, img) img_info = dict(filename=img)
img_info = dict(filename=img_file)
if ann_dir is not None: if ann_dir is not None:
seg_map = osp.join(ann_dir, seg_map = img.replace(img_suffix, seg_map_suffix)
img.replace(img_suffix, seg_map_suffix))
img_info['ann'] = dict(seg_map=seg_map) img_info['ann'] = dict(seg_map=seg_map)
img_infos.append(img_info) img_infos.append(img_info)
@ -165,6 +162,8 @@ class CustomDataset(Dataset):
def pre_pipeline(self, results): def pre_pipeline(self, results):
"""Prepare results dict for pipeline.""" """Prepare results dict for pipeline."""
results['seg_fields'] = [] results['seg_fields'] = []
results['img_prefix'] = self.img_dir
results['seg_prefix'] = self.ann_dir
if self.custom_classes: if self.custom_classes:
results['label_map'] = self.label_map results['label_map'] = self.label_map
@ -225,8 +224,9 @@ class CustomDataset(Dataset):
"""Get ground truth segmentation maps for evaluation.""" """Get ground truth segmentation maps for evaluation."""
gt_seg_maps = [] gt_seg_maps = []
for img_info in self.img_infos: for img_info in self.img_infos:
seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map'])
gt_seg_map = mmcv.imread( gt_seg_map = mmcv.imread(
img_info['ann']['seg_map'], flag='unchanged', backend='pillow') seg_map, flag='unchanged', backend='pillow')
# modify if custom classes # modify if custom classes
if self.label_map is not None: if self.label_map is not None:
for old_id, new_id in self.label_map.items(): for old_id, new_id in self.label_map.items():