From 5a7996db26fb70131675bc771866f86ce9f3599a Mon Sep 17 00:00:00 2001 From: uni19 Date: Wed, 15 Sep 2021 10:16:01 +0800 Subject: [PATCH] [Enhancement] Support loading GT for evaluation from multi-file backend (#867) * support load gt for evaluation from multi-backend * move some code from get_gt_seg_maps to get_one_gt_seg_map * rename gt_seg_map_loader_conf to gt_seg_map_loader_cfg * fix doc str * rename get_one_gt_seg_map to get_gt_seg_map_by_idx --- mmseg/datasets/custom.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/mmseg/datasets/custom.py b/mmseg/datasets/custom.py index 9b0efc6f0..23b347d34 100644 --- a/mmseg/datasets/custom.py +++ b/mmseg/datasets/custom.py @@ -12,7 +12,7 @@ from torch.utils.data import Dataset from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics from mmseg.utils import get_root_logger from .builder import DATASETS -from .pipelines import Compose +from .pipelines import Compose, LoadAnnotations @DATASETS.register_module() @@ -66,6 +66,8 @@ class CustomDataset(Dataset): The palette of segmentation map. If None is given, and self.PALETTE is None, random palette will be generated. Default: None + gt_seg_map_loader_cfg (dict, optional): build LoadAnnotations to + load gt for evaluation, load from disk by default. Default: None. """ CLASSES = None @@ -84,7 +86,8 @@ class CustomDataset(Dataset): ignore_index=255, reduce_zero_label=False, classes=None, - palette=None): + palette=None, + gt_seg_map_loader_cfg=None): self.pipeline = Compose(pipeline) self.img_dir = img_dir self.img_suffix = img_suffix @@ -98,6 +101,10 @@ class CustomDataset(Dataset): self.label_map = None self.CLASSES, self.PALETTE = self.get_classes_and_palette( classes, palette) + self.gt_seg_map_loader = LoadAnnotations( + ) if gt_seg_map_loader_cfg is None else LoadAnnotations( + **gt_seg_map_loader_cfg) + if test_mode: assert self.CLASSES is not None, \ '`cls.CLASSES` or `classes` should be specified when testing' @@ -232,6 +239,14 @@ class CustomDataset(Dataset): """Place holder to format result to dataset specific output.""" raise NotImplementedError + def get_gt_seg_map_by_idx(self, index): + """Get one ground truth segmentation map for evaluation.""" + ann_info = self.get_ann_info(index) + results = dict(ann_info=ann_info) + self.pre_pipeline(results) + self.gt_seg_map_loader(results) + return results['gt_semantic_seg'] + def get_gt_seg_maps(self, efficient_test=None): """Get ground truth segmentation maps for evaluation.""" if efficient_test is not None: @@ -240,11 +255,12 @@ class CustomDataset(Dataset): 'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory ' 'friendly by default. ') - for img_info in self.img_infos: - seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map']) - gt_seg_map = mmcv.imread( - seg_map, flag='unchanged', backend='pillow') - yield gt_seg_map + for idx in range(len(self)): + ann_info = self.get_ann_info(idx) + results = dict(ann_info=ann_info) + self.pre_pipeline(results) + self.gt_seg_map_loader(results) + yield results['gt_semantic_seg'] def pre_eval(self, preds, indices): """Collect eval result from each iteration. @@ -268,9 +284,7 @@ class CustomDataset(Dataset): pre_eval_results = [] for pred, index in zip(preds, indices): - seg_map = osp.join(self.ann_dir, - self.img_infos[index]['ann']['seg_map']) - seg_map = mmcv.imread(seg_map, flag='unchanged', backend='pillow') + seg_map = self.get_gt_seg_map_by_idx(index) pre_eval_results.append( intersect_and_union(pred, seg_map, len(self.CLASSES), self.ignore_index, self.label_map,