mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Enhancement] Support loading GT for evaluation from multi-file backend (#867)
* support load gt for evaluation from multi-backend * move some code from get_gt_seg_maps to get_one_gt_seg_map * rename gt_seg_map_loader_conf to gt_seg_map_loader_cfg * fix doc str * rename get_one_gt_seg_map to get_gt_seg_map_by_idx
This commit is contained in:
parent
56e18ba9aa
commit
5a7996db26
@ -12,7 +12,7 @@ from torch.utils.data import Dataset
|
|||||||
from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics
|
from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics
|
||||||
from mmseg.utils import get_root_logger
|
from mmseg.utils import get_root_logger
|
||||||
from .builder import DATASETS
|
from .builder import DATASETS
|
||||||
from .pipelines import Compose
|
from .pipelines import Compose, LoadAnnotations
|
||||||
|
|
||||||
|
|
||||||
@DATASETS.register_module()
|
@DATASETS.register_module()
|
||||||
@ -66,6 +66,8 @@ class CustomDataset(Dataset):
|
|||||||
The palette of segmentation map. If None is given, and
|
The palette of segmentation map. If None is given, and
|
||||||
self.PALETTE is None, random palette will be generated.
|
self.PALETTE is None, random palette will be generated.
|
||||||
Default: None
|
Default: None
|
||||||
|
gt_seg_map_loader_cfg (dict, optional): build LoadAnnotations to
|
||||||
|
load gt for evaluation, load from disk by default. Default: None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
CLASSES = None
|
CLASSES = None
|
||||||
@ -84,7 +86,8 @@ class CustomDataset(Dataset):
|
|||||||
ignore_index=255,
|
ignore_index=255,
|
||||||
reduce_zero_label=False,
|
reduce_zero_label=False,
|
||||||
classes=None,
|
classes=None,
|
||||||
palette=None):
|
palette=None,
|
||||||
|
gt_seg_map_loader_cfg=None):
|
||||||
self.pipeline = Compose(pipeline)
|
self.pipeline = Compose(pipeline)
|
||||||
self.img_dir = img_dir
|
self.img_dir = img_dir
|
||||||
self.img_suffix = img_suffix
|
self.img_suffix = img_suffix
|
||||||
@ -98,6 +101,10 @@ class CustomDataset(Dataset):
|
|||||||
self.label_map = None
|
self.label_map = None
|
||||||
self.CLASSES, self.PALETTE = self.get_classes_and_palette(
|
self.CLASSES, self.PALETTE = self.get_classes_and_palette(
|
||||||
classes, palette)
|
classes, palette)
|
||||||
|
self.gt_seg_map_loader = LoadAnnotations(
|
||||||
|
) if gt_seg_map_loader_cfg is None else LoadAnnotations(
|
||||||
|
**gt_seg_map_loader_cfg)
|
||||||
|
|
||||||
if test_mode:
|
if test_mode:
|
||||||
assert self.CLASSES is not None, \
|
assert self.CLASSES is not None, \
|
||||||
'`cls.CLASSES` or `classes` should be specified when testing'
|
'`cls.CLASSES` or `classes` should be specified when testing'
|
||||||
@ -232,6 +239,14 @@ class CustomDataset(Dataset):
|
|||||||
"""Place holder to format result to dataset specific output."""
|
"""Place holder to format result to dataset specific output."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_gt_seg_map_by_idx(self, index):
|
||||||
|
"""Get one ground truth segmentation map for evaluation."""
|
||||||
|
ann_info = self.get_ann_info(index)
|
||||||
|
results = dict(ann_info=ann_info)
|
||||||
|
self.pre_pipeline(results)
|
||||||
|
self.gt_seg_map_loader(results)
|
||||||
|
return results['gt_semantic_seg']
|
||||||
|
|
||||||
def get_gt_seg_maps(self, efficient_test=None):
|
def get_gt_seg_maps(self, efficient_test=None):
|
||||||
"""Get ground truth segmentation maps for evaluation."""
|
"""Get ground truth segmentation maps for evaluation."""
|
||||||
if efficient_test is not None:
|
if efficient_test is not None:
|
||||||
@ -240,11 +255,12 @@ class CustomDataset(Dataset):
|
|||||||
'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory '
|
'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory '
|
||||||
'friendly by default. ')
|
'friendly by default. ')
|
||||||
|
|
||||||
for img_info in self.img_infos:
|
for idx in range(len(self)):
|
||||||
seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map'])
|
ann_info = self.get_ann_info(idx)
|
||||||
gt_seg_map = mmcv.imread(
|
results = dict(ann_info=ann_info)
|
||||||
seg_map, flag='unchanged', backend='pillow')
|
self.pre_pipeline(results)
|
||||||
yield gt_seg_map
|
self.gt_seg_map_loader(results)
|
||||||
|
yield results['gt_semantic_seg']
|
||||||
|
|
||||||
def pre_eval(self, preds, indices):
|
def pre_eval(self, preds, indices):
|
||||||
"""Collect eval result from each iteration.
|
"""Collect eval result from each iteration.
|
||||||
@ -268,9 +284,7 @@ class CustomDataset(Dataset):
|
|||||||
pre_eval_results = []
|
pre_eval_results = []
|
||||||
|
|
||||||
for pred, index in zip(preds, indices):
|
for pred, index in zip(preds, indices):
|
||||||
seg_map = osp.join(self.ann_dir,
|
seg_map = self.get_gt_seg_map_by_idx(index)
|
||||||
self.img_infos[index]['ann']['seg_map'])
|
|
||||||
seg_map = mmcv.imread(seg_map, flag='unchanged', backend='pillow')
|
|
||||||
pre_eval_results.append(
|
pre_eval_results.append(
|
||||||
intersect_and_union(pred, seg_map, len(self.CLASSES),
|
intersect_and_union(pred, seg_map, len(self.CLASSES),
|
||||||
self.ignore_index, self.label_map,
|
self.ignore_index, self.label_map,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user