[Enhancement] Support loading GT for evaluation from multi-file backend (#867)

* support load gt for evaluation from multi-backend

* move some code from get_gt_seg_maps to get_one_gt_seg_map

* rename gt_seg_map_loader_conf to gt_seg_map_loader_cfg

* fix doc str

* rename get_one_gt_seg_map to get_gt_seg_map_by_idx
This commit is contained in:
uni19 2021-09-15 10:16:01 +08:00 committed by GitHub
parent 56e18ba9aa
commit 5a7996db26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -12,7 +12,7 @@ from torch.utils.data import Dataset
from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics
from mmseg.utils import get_root_logger from mmseg.utils import get_root_logger
from .builder import DATASETS from .builder import DATASETS
from .pipelines import Compose from .pipelines import Compose, LoadAnnotations
@DATASETS.register_module() @DATASETS.register_module()
@ -66,6 +66,8 @@ class CustomDataset(Dataset):
The palette of segmentation map. If None is given, and The palette of segmentation map. If None is given, and
self.PALETTE is None, random palette will be generated. self.PALETTE is None, random palette will be generated.
Default: None Default: None
gt_seg_map_loader_cfg (dict, optional): build LoadAnnotations to
load gt for evaluation, load from disk by default. Default: None.
""" """
CLASSES = None CLASSES = None
@ -84,7 +86,8 @@ class CustomDataset(Dataset):
ignore_index=255, ignore_index=255,
reduce_zero_label=False, reduce_zero_label=False,
classes=None, classes=None,
palette=None): palette=None,
gt_seg_map_loader_cfg=None):
self.pipeline = Compose(pipeline) self.pipeline = Compose(pipeline)
self.img_dir = img_dir self.img_dir = img_dir
self.img_suffix = img_suffix self.img_suffix = img_suffix
@ -98,6 +101,10 @@ class CustomDataset(Dataset):
self.label_map = None self.label_map = None
self.CLASSES, self.PALETTE = self.get_classes_and_palette( self.CLASSES, self.PALETTE = self.get_classes_and_palette(
classes, palette) classes, palette)
self.gt_seg_map_loader = LoadAnnotations(
) if gt_seg_map_loader_cfg is None else LoadAnnotations(
**gt_seg_map_loader_cfg)
if test_mode: if test_mode:
assert self.CLASSES is not None, \ assert self.CLASSES is not None, \
'`cls.CLASSES` or `classes` should be specified when testing' '`cls.CLASSES` or `classes` should be specified when testing'
@ -232,6 +239,14 @@ class CustomDataset(Dataset):
"""Place holder to format result to dataset specific output.""" """Place holder to format result to dataset specific output."""
raise NotImplementedError raise NotImplementedError
def get_gt_seg_map_by_idx(self, index):
"""Get one ground truth segmentation map for evaluation."""
ann_info = self.get_ann_info(index)
results = dict(ann_info=ann_info)
self.pre_pipeline(results)
self.gt_seg_map_loader(results)
return results['gt_semantic_seg']
def get_gt_seg_maps(self, efficient_test=None): def get_gt_seg_maps(self, efficient_test=None):
"""Get ground truth segmentation maps for evaluation.""" """Get ground truth segmentation maps for evaluation."""
if efficient_test is not None: if efficient_test is not None:
@ -240,11 +255,12 @@ class CustomDataset(Dataset):
'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory ' 'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory '
'friendly by default. ') 'friendly by default. ')
for img_info in self.img_infos: for idx in range(len(self)):
seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map']) ann_info = self.get_ann_info(idx)
gt_seg_map = mmcv.imread( results = dict(ann_info=ann_info)
seg_map, flag='unchanged', backend='pillow') self.pre_pipeline(results)
yield gt_seg_map self.gt_seg_map_loader(results)
yield results['gt_semantic_seg']
def pre_eval(self, preds, indices): def pre_eval(self, preds, indices):
"""Collect eval result from each iteration. """Collect eval result from each iteration.
@ -268,9 +284,7 @@ class CustomDataset(Dataset):
pre_eval_results = [] pre_eval_results = []
for pred, index in zip(preds, indices): for pred, index in zip(preds, indices):
seg_map = osp.join(self.ann_dir, seg_map = self.get_gt_seg_map_by_idx(index)
self.img_infos[index]['ann']['seg_map'])
seg_map = mmcv.imread(seg_map, flag='unchanged', backend='pillow')
pre_eval_results.append( pre_eval_results.append(
intersect_and_union(pred, seg_map, len(self.CLASSES), intersect_and_union(pred, seg_map, len(self.CLASSES),
self.ignore_index, self.label_map, self.ignore_index, self.label_map,