[Enhancement] Support loading GT for evaluation from multi-file backend (#867)
* support load gt for evaluation from multi-backend * move some code from get_gt_seg_maps to get_one_gt_seg_map * rename gt_seg_map_loader_conf to gt_seg_map_loader_cfg * fix doc str * rename get_one_gt_seg_map to get_gt_seg_map_by_idxpull/1801/head
parent
730f36cd8b
commit
4583dc1033
|
@ -12,7 +12,7 @@ from torch.utils.data import Dataset
|
|||
from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics
|
||||
from mmseg.utils import get_root_logger
|
||||
from .builder import DATASETS
|
||||
from .pipelines import Compose
|
||||
from .pipelines import Compose, LoadAnnotations
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
@ -66,6 +66,8 @@ class CustomDataset(Dataset):
|
|||
The palette of segmentation map. If None is given, and
|
||||
self.PALETTE is None, random palette will be generated.
|
||||
Default: None
|
||||
gt_seg_map_loader_cfg (dict, optional): build LoadAnnotations to
|
||||
load gt for evaluation, load from disk by default. Default: None.
|
||||
"""
|
||||
|
||||
CLASSES = None
|
||||
|
@ -84,7 +86,8 @@ class CustomDataset(Dataset):
|
|||
ignore_index=255,
|
||||
reduce_zero_label=False,
|
||||
classes=None,
|
||||
palette=None):
|
||||
palette=None,
|
||||
gt_seg_map_loader_cfg=None):
|
||||
self.pipeline = Compose(pipeline)
|
||||
self.img_dir = img_dir
|
||||
self.img_suffix = img_suffix
|
||||
|
@ -98,6 +101,10 @@ class CustomDataset(Dataset):
|
|||
self.label_map = None
|
||||
self.CLASSES, self.PALETTE = self.get_classes_and_palette(
|
||||
classes, palette)
|
||||
self.gt_seg_map_loader = LoadAnnotations(
|
||||
) if gt_seg_map_loader_cfg is None else LoadAnnotations(
|
||||
**gt_seg_map_loader_cfg)
|
||||
|
||||
if test_mode:
|
||||
assert self.CLASSES is not None, \
|
||||
'`cls.CLASSES` or `classes` should be specified when testing'
|
||||
|
@ -232,6 +239,14 @@ class CustomDataset(Dataset):
|
|||
"""Place holder to format result to dataset specific output."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_gt_seg_map_by_idx(self, index):
|
||||
"""Get one ground truth segmentation map for evaluation."""
|
||||
ann_info = self.get_ann_info(index)
|
||||
results = dict(ann_info=ann_info)
|
||||
self.pre_pipeline(results)
|
||||
self.gt_seg_map_loader(results)
|
||||
return results['gt_semantic_seg']
|
||||
|
||||
def get_gt_seg_maps(self, efficient_test=None):
|
||||
"""Get ground truth segmentation maps for evaluation."""
|
||||
if efficient_test is not None:
|
||||
|
@ -240,11 +255,12 @@ class CustomDataset(Dataset):
|
|||
'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory '
|
||||
'friendly by default. ')
|
||||
|
||||
for img_info in self.img_infos:
|
||||
seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map'])
|
||||
gt_seg_map = mmcv.imread(
|
||||
seg_map, flag='unchanged', backend='pillow')
|
||||
yield gt_seg_map
|
||||
for idx in range(len(self)):
|
||||
ann_info = self.get_ann_info(idx)
|
||||
results = dict(ann_info=ann_info)
|
||||
self.pre_pipeline(results)
|
||||
self.gt_seg_map_loader(results)
|
||||
yield results['gt_semantic_seg']
|
||||
|
||||
def pre_eval(self, preds, indices):
|
||||
"""Collect eval result from each iteration.
|
||||
|
@ -268,9 +284,7 @@ class CustomDataset(Dataset):
|
|||
pre_eval_results = []
|
||||
|
||||
for pred, index in zip(preds, indices):
|
||||
seg_map = osp.join(self.ann_dir,
|
||||
self.img_infos[index]['ann']['seg_map'])
|
||||
seg_map = mmcv.imread(seg_map, flag='unchanged', backend='pillow')
|
||||
seg_map = self.get_gt_seg_map_by_idx(index)
|
||||
pre_eval_results.append(
|
||||
intersect_and_union(pred, seg_map, len(self.CLASSES),
|
||||
self.ignore_index, self.label_map,
|
||||
|
|
Loading…
Reference in New Issue