[Refactor] Refactor cityscapes metrics

This commit is contained in:
linfangjian.vendor 2022-06-28 03:21:33 +00:00 committed by zhengmiao
parent eef12a064b
commit 6053345b3d
9 changed files with 266 additions and 5 deletions

View File

@ -41,8 +41,9 @@ class PackSegInputs(BaseTransform):
"""
def __init__(self,
meta_keys=('img_path', 'ori_shape', 'img_shape', 'pad_shape',
'scale_factor', 'flip', 'flip_direction')):
meta_keys=('img_path', 'seg_map_path', 'ori_shape',
'img_shape', 'pad_shape', 'scale_factor', 'flip',
'flip_direction')):
self.meta_keys = meta_keys
def transform(self, results: dict) -> dict:

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .citys_metric import CitysMetric
from .iou_metric import IoUMetric
__all__ = ['IoUMetric']
__all__ = ['IoUMetric', 'CitysMetric']

View File

@ -0,0 +1,147 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, List, Optional, Sequence
import mmcv
import numpy as np
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger, print_log
from PIL import Image
from mmseg.registry import METRICS
@METRICS.register_module()
class CitysMetric(BaseMetric):
"""Cityscapes evaluation metric.
Args:
ignore_index (int): Index that will be ignored in evaluation.
Default: 255.
citys_metrics (list[str] | str): Metrics to be evaluated,
Default: ['cityscapes'].
to_label_id (bool): whether convert output to label_id for
submission. Default: True.
suffix (str): The filename prefix of the png files.
If the prefix is "somepath/xxx", the png files will be
named "somepath/xxx.png". Default: '.format_cityscapes'.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Defaults to None.
"""
def __init__(self,
ignore_index: int = 255,
citys_metrics: List[str] = ['cityscapes'],
to_label_id: bool = True,
suffix: str = '.format_cityscapes',
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device=collect_device, prefix=prefix)
self.ignore_index = ignore_index
self.metrics = citys_metrics
assert self.metrics[0] == 'cityscapes'
self.to_label_id = to_label_id
self.suffix = suffix
def process(self, data_batch: Sequence[dict],
predictions: Sequence[dict]) -> None:
"""Process one batch of data and predictions.
The processed results should be stored in ``self.results``, which will
be used to computed the metrics when all batches have been processed.
Args:
data_batch (Sequence[dict]): A batch of data from the dataloader.
predictions (Sequence[dict]): A batch of outputs from the model.
"""
mmcv.mkdir_or_exist(self.suffix)
for pred in predictions:
pred_label = pred['pred_sem_seg']['data'][0].cpu().numpy()
# results2img
if self.to_label_id:
pred_label = self._convert_to_label_id(pred_label)
basename = osp.splitext(osp.basename(pred['img_path']))[0]
png_filename = osp.join(self.suffix, f'{basename}.png')
output = Image.fromarray(pred_label.astype(np.uint8)).convert('P')
import cityscapesscripts.helpers.labels as CSLabels
palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
for label_id, label in CSLabels.id2label.items():
palette[label_id] = label.color
output.putpalette(palette)
output.save(png_filename)
ann_dir = osp.join(
data_batch[0]['data_sample']['seg_map_path'].split('val')[0],
'val')
self.results.append(ann_dir)
def compute_metrics(self, results: list) -> Dict[str, float]:
"""Compute the metrics from processed results.
Args:
results (list): Testing results of the dataset.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
imgfile_prefix (str | None): The prefix of output image file
Returns:
dict[str: float]: Cityscapes evaluation results.
"""
logger: MMLogger = MMLogger.get_current_instance()
try:
import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa
except ImportError:
raise ImportError('Please run "pip install cityscapesscripts" to '
'install cityscapesscripts first.')
msg = 'Evaluating in Cityscapes style'
if logger is None:
msg = '\n' + msg
print_log(msg, logger=logger)
result_dir = self.suffix
eval_results = dict()
print_log(f'Evaluating results under {result_dir} ...', logger=logger)
CSEval.args.evalInstLevelScore = True
CSEval.args.predictionPath = osp.abspath(result_dir)
CSEval.args.evalPixelAccuracy = True
CSEval.args.JSONOutput = False
seg_map_list = []
pred_list = []
ann_dir = results[0]
# when evaluating with official cityscapesscripts,
# **_gtFine_labelIds.png is used
for seg_map in mmcv.scandir(
ann_dir, 'gtFine_labelIds.png', recursive=True):
seg_map_list.append(osp.join(ann_dir, seg_map))
pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))
metric = dict()
eval_results.update(
CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
metric['averageScoreCategories'] = eval_results[
'averageScoreCategories']
metric['averageScoreInstCategories'] = eval_results[
'averageScoreInstCategories']
return metric
@staticmethod
def _convert_to_label_id(result):
"""Convert trainId to id for cityscapes."""
if isinstance(result, str):
result = np.load(result)
import cityscapesscripts.helpers.labels as CSLabels
result_copy = result.copy()
for trainId, label in CSLabels.trainId2label.items():
result_copy[result == trainId] = label.id
return result_copy

View File

@ -173,10 +173,10 @@ def test_cityscapes():
data_prefix=dict(
img_path=osp.join(
osp.dirname(__file__),
'../data/pseudo_cityscapes_dataset/leftImg8bit'),
'../data/pseudo_cityscapes_dataset/leftImg8bit/val'),
seg_map_path=osp.join(
osp.dirname(__file__),
'../data/pseudo_cityscapes_dataset/gtFine')))
'../data/pseudo_cityscapes_dataset/gtFine/val')))
assert len(test_dataset) == 1

View File

@ -0,0 +1,112 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import torch
from mmengine.data import BaseDataElement, PixelData
from mmseg.core import SegDataSample
from mmseg.metrics import CitysMetric
class TestCitysMetric(TestCase):
def _demo_mm_inputs(self,
batch_size=1,
image_shapes=(3, 128, 256),
num_classes=5):
"""Create a superset of inputs needed to run test or train batches.
Args:
batch_size (int): batch size. Default to 2.
image_shapes (List[tuple], Optional): image shape.
Default to (3, 64, 64)
num_classes (int): number of different classes.
Default to 5.
"""
if isinstance(image_shapes, list):
assert len(image_shapes) == batch_size
else:
image_shapes = [image_shapes] * batch_size
packed_inputs = []
for idx in range(batch_size):
image_shape = image_shapes[idx]
_, h, w = image_shape
mm_inputs = dict()
data_sample = SegDataSample()
gt_semantic_seg = np.random.randint(
0, num_classes, (1, h, w), dtype=np.uint8)
gt_semantic_seg = torch.LongTensor(gt_semantic_seg)
gt_sem_seg_data = dict(data=gt_semantic_seg)
data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
mm_inputs['data_sample'] = data_sample.to_dict()
mm_inputs['data_sample']['seg_map_path'] = \
'tests/data/pseudo_cityscapes_dataset/gtFine/val/\
frankfurt/frankfurt_000000_000294_gtFine_labelTrainIds.png'
packed_inputs.append(mm_inputs)
return packed_inputs
def _demo_mm_model_output(self,
batch_size=1,
image_shapes=(3, 128, 256),
num_classes=5):
"""Create a superset of inputs needed to run test or train batches.
Args:
batch_size (int): batch size. Default to 2.
image_shapes (List[tuple], Optional): image shape.
Default to (3, 64, 64)
num_classes (int): number of different classes.
Default to 5.
"""
results_dict = dict()
_, h, w = image_shapes
seg_logit = torch.randn(batch_size, num_classes, h, w)
results_dict['seg_logits'] = seg_logit
seg_pred = np.random.randint(
0, num_classes, (batch_size, h, w), dtype=np.uint8)
seg_pred = torch.LongTensor(seg_pred)
results_dict['pred_sem_seg'] = seg_pred
batch_datasampes = [
SegDataSample()
for _ in range(results_dict['pred_sem_seg'].shape[0])
]
for key, value in results_dict.items():
for i in range(value.shape[0]):
setattr(batch_datasampes[i], key, PixelData(data=value[i]))
_predictions = []
for pred in batch_datasampes:
if isinstance(pred, BaseDataElement):
test_data = pred.to_dict()
test_data['img_path'] = \
'tests/data/pseudo_cityscapes_dataset/leftImg8bit/val/\
frankfurt/frankfurt_000000_000294_leftImg8bit.png'
_predictions.append(test_data)
else:
_predictions.append(pred)
return _predictions
def test_evaluate(self):
"""Test using the metric in the same way as Evalutor."""
data_batch = self._demo_mm_inputs()
predictions = self._demo_mm_model_output()
iou_metric = CitysMetric(citys_metrics=['cityscapes'])
iou_metric.process(data_batch, predictions)
res = iou_metric.evaluate(6)
self.assertIsInstance(res, dict)
# test to_label_id = True
iou_metric = CitysMetric(
citys_metrics=['cityscapes'], to_label_id=True)
iou_metric.process(data_batch, predictions)
res = iou_metric.evaluate(6)
self.assertIsInstance(res, dict)
import shutil
shutil.rmtree('.format_cityscapes')