[Enhancement] Support loading GT for evaluation from multi-file backend (#867)

* support loading GT for evaluation from a multi-file backend (see the sketch below)

* move the GT loading code from get_gt_seg_maps into get_one_gt_seg_map

* rename gt_seg_map_loader_conf to gt_seg_map_loader_cfg

* fix docstring

* rename get_one_gt_seg_map to get_gt_seg_map_by_idx
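
With this change the loader used for the evaluation ground truth is built from gt_seg_map_loader_cfg, so the options accepted by LoadAnnotations in the data pipeline apply to evaluation as well. A minimal sketch of passing the new argument, assuming the LoadAnnotations in your mmcv/mmseg version accepts file_client_args; the paths and the petrel backend are illustrative, not part of this commit:

    from mmseg.datasets import CustomDataset

    dataset = CustomDataset(
        pipeline=[],                             # evaluation itself does not run this pipeline
        img_dir='data/my_dataset/img_dir/val',   # hypothetical path
        ann_dir='data/my_dataset/ann_dir/val',   # hypothetical path
        # Forwarded to LoadAnnotations, e.g. to read GT from a Petrel/Ceph
        # bucket instead of local disk, if that backend exists in your mmcv.
        gt_seg_map_loader_cfg=dict(file_client_args=dict(backend='petrel')))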
uni19 2021-09-15 10:16:01 +08:00 committed by GitHub
parent 730f36cd8b
commit 4583dc1033
1 changed file with 24 additions and 10 deletions


@@ -12,7 +12,7 @@ from torch.utils.data import Dataset
 from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics
 from mmseg.utils import get_root_logger
 from .builder import DATASETS
-from .pipelines import Compose
+from .pipelines import Compose, LoadAnnotations

 @DATASETS.register_module()
@@ -66,6 +66,8 @@ class CustomDataset(Dataset):
             The palette of segmentation map. If None is given, and
             self.PALETTE is None, random palette will be generated.
             Default: None
+        gt_seg_map_loader_cfg (dict, optional): Config to build LoadAnnotations
+            for loading gt in evaluation; loads from disk by default. Default: None.
     """

     CLASSES = None
@@ -84,7 +86,8 @@
                  ignore_index=255,
                  reduce_zero_label=False,
                  classes=None,
-                 palette=None):
+                 palette=None,
+                 gt_seg_map_loader_cfg=None):
         self.pipeline = Compose(pipeline)
         self.img_dir = img_dir
         self.img_suffix = img_suffix
@@ -98,6 +101,10 @@
         self.label_map = None
         self.CLASSES, self.PALETTE = self.get_classes_and_palette(
             classes, palette)
+        self.gt_seg_map_loader = LoadAnnotations(
+        ) if gt_seg_map_loader_cfg is None else LoadAnnotations(
+            **gt_seg_map_loader_cfg)

         if test_mode:
             assert self.CLASSES is not None, \
                 '`cls.CLASSES` or `classes` should be specified when testing'
@@ -232,6 +239,14 @@
         """Place holder to format result to dataset specific output."""
         raise NotImplementedError

+    def get_gt_seg_map_by_idx(self, index):
+        """Get one ground truth segmentation map for evaluation."""
+        ann_info = self.get_ann_info(index)
+        results = dict(ann_info=ann_info)
+        self.pre_pipeline(results)
+        self.gt_seg_map_loader(results)
+        return results['gt_semantic_seg']
+
     def get_gt_seg_maps(self, efficient_test=None):
         """Get ground truth segmentation maps for evaluation."""
         if efficient_test is not None:
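
The new get_gt_seg_map_by_idx returns a single ground truth map as a numpy array, routed through pre_pipeline and the configured loader instead of a hard-coded mmcv.imread from ann_dir; get_gt_seg_maps() keeps its generator behaviour but now yields maps from the same loader. A small usage sketch, continuing the hypothetical dataset above and assuming the positional eval_metrics(results, gt_seg_maps, num_classes, ignore_index) signature:

    from mmseg.core import eval_metrics

    # Fetch the GT map of the first sample as an (H, W) array of label ids.
    gt = dataset.get_gt_seg_map_by_idx(0)
    print(gt.shape, gt.dtype)

    # Offline evaluation: results is assumed to be a list of (H, W) numpy
    # predictions, one per sample and in dataset order.
    metrics = eval_metrics(results, dataset.get_gt_seg_maps(),
                           len(dataset.CLASSES), dataset.ignore_index)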
@@ -240,11 +255,12 @@
                 'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory '
                 'friendly by default. ')

-        for img_info in self.img_infos:
-            seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map'])
-            gt_seg_map = mmcv.imread(
-                seg_map, flag='unchanged', backend='pillow')
-            yield gt_seg_map
+        for idx in range(len(self)):
+            ann_info = self.get_ann_info(idx)
+            results = dict(ann_info=ann_info)
+            self.pre_pipeline(results)
+            self.gt_seg_map_loader(results)
+            yield results['gt_semantic_seg']

     def pre_eval(self, preds, indices):
         """Collect eval result from each iteration.
@@ -268,9 +284,7 @@
         pre_eval_results = []

         for pred, index in zip(preds, indices):
-            seg_map = osp.join(self.ann_dir,
-                               self.img_infos[index]['ann']['seg_map'])
-            seg_map = mmcv.imread(seg_map, flag='unchanged', backend='pillow')
+            seg_map = self.get_gt_seg_map_by_idx(index)
             pre_eval_results.append(
                 intersect_and_union(pred, seg_map, len(self.CLASSES),
                                     self.ignore_index, self.label_map,
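
Since pre_eval now goes through get_gt_seg_map_by_idx, per-iteration evaluation also works when annotations come from a non-disk backend. A sketch of the intended flow, assuming an iterator that yields batches of predictions together with their dataset indices (prediction_batches is hypothetical):

    from mmseg.core import pre_eval_to_metrics

    pre_eval_results = []
    for batch_preds, batch_indices in prediction_batches:  # hypothetical iterator
        # Each call loads the matching GT through the configured loader.
        pre_eval_results.extend(dataset.pre_eval(batch_preds, batch_indices))

    # Reduce the collected intersect-and-union results to metrics (mIoU by default).
    metrics = pre_eval_to_metrics(pre_eval_results)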