[Refactor] Refactor cityscapes metrics
@ -41,8 +41,9 @@ class PackSegInputs(BaseTransform):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape', 'pad_shape',
|
||||
'scale_factor', 'flip', 'flip_direction')):
|
||||
meta_keys=('img_path', 'seg_map_path', 'ori_shape',
|
||||
'img_shape', 'pad_shape', 'scale_factor', 'flip',
|
||||
'flip_direction')):
|
||||
self.meta_keys = meta_keys
|
||||
|
||||
def transform(self, results: dict) -> dict:
|
||||
|
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .citys_metric import CitysMetric
|
||||
from .iou_metric import IoUMetric
|
||||
|
||||
__all__ = ['IoUMetric']
|
||||
__all__ = ['IoUMetric', 'CitysMetric']
|
||||
|
147
mmseg/metrics/citys_metric.py
Normal file
@ -0,0 +1,147 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmengine.evaluator import BaseMetric
|
||||
from mmengine.logging import MMLogger, print_log
|
||||
from PIL import Image
|
||||
|
||||
from mmseg.registry import METRICS
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
class CitysMetric(BaseMetric):
|
||||
"""Cityscapes evaluation metric.
|
||||
|
||||
Args:
|
||||
ignore_index (int): Index that will be ignored in evaluation.
|
||||
Default: 255.
|
||||
citys_metrics (list[str] | str): Metrics to be evaluated,
|
||||
Default: ['cityscapes'].
|
||||
to_label_id (bool): whether convert output to label_id for
|
||||
submission. Default: True.
|
||||
suffix (str): The filename prefix of the png files.
|
||||
If the prefix is "somepath/xxx", the png files will be
|
||||
named "somepath/xxx.png". Default: '.format_cityscapes'.
|
||||
collect_device (str): Device name used for collecting results from
|
||||
different ranks during distributed training. Must be 'cpu' or
|
||||
'gpu'. Defaults to 'cpu'.
|
||||
prefix (str, optional): The prefix that will be added in the metric
|
||||
names to disambiguate homonymous metrics of different evaluators.
|
||||
If prefix is not provided in the argument, self.default_prefix
|
||||
will be used instead. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ignore_index: int = 255,
|
||||
citys_metrics: List[str] = ['cityscapes'],
|
||||
to_label_id: bool = True,
|
||||
suffix: str = '.format_cityscapes',
|
||||
collect_device: str = 'cpu',
|
||||
prefix: Optional[str] = None) -> None:
|
||||
super().__init__(collect_device=collect_device, prefix=prefix)
|
||||
|
||||
self.ignore_index = ignore_index
|
||||
self.metrics = citys_metrics
|
||||
assert self.metrics[0] == 'cityscapes'
|
||||
self.to_label_id = to_label_id
|
||||
self.suffix = suffix
|
||||
|
||||
def process(self, data_batch: Sequence[dict],
|
||||
predictions: Sequence[dict]) -> None:
|
||||
"""Process one batch of data and predictions.
|
||||
|
||||
The processed results should be stored in ``self.results``, which will
|
||||
be used to computed the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[dict]): A batch of data from the dataloader.
|
||||
predictions (Sequence[dict]): A batch of outputs from the model.
|
||||
"""
|
||||
mmcv.mkdir_or_exist(self.suffix)
|
||||
|
||||
for pred in predictions:
|
||||
pred_label = pred['pred_sem_seg']['data'][0].cpu().numpy()
|
||||
# results2img
|
||||
if self.to_label_id:
|
||||
pred_label = self._convert_to_label_id(pred_label)
|
||||
basename = osp.splitext(osp.basename(pred['img_path']))[0]
|
||||
png_filename = osp.join(self.suffix, f'{basename}.png')
|
||||
output = Image.fromarray(pred_label.astype(np.uint8)).convert('P')
|
||||
import cityscapesscripts.helpers.labels as CSLabels
|
||||
palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
|
||||
for label_id, label in CSLabels.id2label.items():
|
||||
palette[label_id] = label.color
|
||||
output.putpalette(palette)
|
||||
output.save(png_filename)
|
||||
|
||||
ann_dir = osp.join(
|
||||
data_batch[0]['data_sample']['seg_map_path'].split('val')[0],
|
||||
'val')
|
||||
self.results.append(ann_dir)
|
||||
|
||||
def compute_metrics(self, results: list) -> Dict[str, float]:
|
||||
"""Compute the metrics from processed results.
|
||||
|
||||
Args:
|
||||
results (list): Testing results of the dataset.
|
||||
logger (logging.Logger | str | None): Logger used for printing
|
||||
related information during evaluation. Default: None.
|
||||
imgfile_prefix (str | None): The prefix of output image file
|
||||
|
||||
Returns:
|
||||
dict[str: float]: Cityscapes evaluation results.
|
||||
"""
|
||||
logger: MMLogger = MMLogger.get_current_instance()
|
||||
try:
|
||||
import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa
|
||||
except ImportError:
|
||||
raise ImportError('Please run "pip install cityscapesscripts" to '
|
||||
'install cityscapesscripts first.')
|
||||
msg = 'Evaluating in Cityscapes style'
|
||||
|
||||
if logger is None:
|
||||
msg = '\n' + msg
|
||||
print_log(msg, logger=logger)
|
||||
|
||||
result_dir = self.suffix
|
||||
|
||||
eval_results = dict()
|
||||
print_log(f'Evaluating results under {result_dir} ...', logger=logger)
|
||||
|
||||
CSEval.args.evalInstLevelScore = True
|
||||
CSEval.args.predictionPath = osp.abspath(result_dir)
|
||||
CSEval.args.evalPixelAccuracy = True
|
||||
CSEval.args.JSONOutput = False
|
||||
|
||||
seg_map_list = []
|
||||
pred_list = []
|
||||
ann_dir = results[0]
|
||||
# when evaluating with official cityscapesscripts,
|
||||
# **_gtFine_labelIds.png is used
|
||||
for seg_map in mmcv.scandir(
|
||||
ann_dir, 'gtFine_labelIds.png', recursive=True):
|
||||
seg_map_list.append(osp.join(ann_dir, seg_map))
|
||||
pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))
|
||||
metric = dict()
|
||||
eval_results.update(
|
||||
CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
|
||||
metric['averageScoreCategories'] = eval_results[
|
||||
'averageScoreCategories']
|
||||
metric['averageScoreInstCategories'] = eval_results[
|
||||
'averageScoreInstCategories']
|
||||
return metric
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_label_id(result):
|
||||
"""Convert trainId to id for cityscapes."""
|
||||
if isinstance(result, str):
|
||||
result = np.load(result)
|
||||
import cityscapesscripts.helpers.labels as CSLabels
|
||||
result_copy = result.copy()
|
||||
for trainId, label in CSLabels.trainId2label.items():
|
||||
result_copy[result == trainId] = label.id
|
||||
|
||||
return result_copy
|
Before Width: | Height: | Size: 1.9 KiB After Width: | Height: | Size: 1.9 KiB |
Before Width: | Height: | Size: 1.5 KiB After Width: | Height: | Size: 1.5 KiB |
Before Width: | Height: | Size: 1.5 KiB After Width: | Height: | Size: 1.5 KiB |
Before Width: | Height: | Size: 50 KiB After Width: | Height: | Size: 50 KiB |
@ -173,10 +173,10 @@ def test_cityscapes():
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_cityscapes_dataset/leftImg8bit'),
|
||||
'../data/pseudo_cityscapes_dataset/leftImg8bit/val'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_cityscapes_dataset/gtFine')))
|
||||
'../data/pseudo_cityscapes_dataset/gtFine/val')))
|
||||
assert len(test_dataset) == 1
|
||||
|
||||
|
||||
|
112
tests/test_metrics/test_citys_metric.py
Normal file
@ -0,0 +1,112 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.data import BaseDataElement, PixelData
|
||||
|
||||
from mmseg.core import SegDataSample
|
||||
from mmseg.metrics import CitysMetric
|
||||
|
||||
|
||||
class TestCitysMetric(TestCase):
|
||||
|
||||
def _demo_mm_inputs(self,
|
||||
batch_size=1,
|
||||
image_shapes=(3, 128, 256),
|
||||
num_classes=5):
|
||||
"""Create a superset of inputs needed to run test or train batches.
|
||||
|
||||
Args:
|
||||
batch_size (int): batch size. Default to 2.
|
||||
image_shapes (List[tuple], Optional): image shape.
|
||||
Default to (3, 64, 64)
|
||||
num_classes (int): number of different classes.
|
||||
Default to 5.
|
||||
"""
|
||||
if isinstance(image_shapes, list):
|
||||
assert len(image_shapes) == batch_size
|
||||
else:
|
||||
image_shapes = [image_shapes] * batch_size
|
||||
|
||||
packed_inputs = []
|
||||
for idx in range(batch_size):
|
||||
image_shape = image_shapes[idx]
|
||||
_, h, w = image_shape
|
||||
|
||||
mm_inputs = dict()
|
||||
data_sample = SegDataSample()
|
||||
gt_semantic_seg = np.random.randint(
|
||||
0, num_classes, (1, h, w), dtype=np.uint8)
|
||||
gt_semantic_seg = torch.LongTensor(gt_semantic_seg)
|
||||
gt_sem_seg_data = dict(data=gt_semantic_seg)
|
||||
data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||
mm_inputs['data_sample'] = data_sample.to_dict()
|
||||
mm_inputs['data_sample']['seg_map_path'] = \
|
||||
'tests/data/pseudo_cityscapes_dataset/gtFine/val/\
|
||||
frankfurt/frankfurt_000000_000294_gtFine_labelTrainIds.png'
|
||||
|
||||
packed_inputs.append(mm_inputs)
|
||||
|
||||
return packed_inputs
|
||||
|
||||
def _demo_mm_model_output(self,
|
||||
batch_size=1,
|
||||
image_shapes=(3, 128, 256),
|
||||
num_classes=5):
|
||||
"""Create a superset of inputs needed to run test or train batches.
|
||||
|
||||
Args:
|
||||
batch_size (int): batch size. Default to 2.
|
||||
image_shapes (List[tuple], Optional): image shape.
|
||||
Default to (3, 64, 64)
|
||||
num_classes (int): number of different classes.
|
||||
Default to 5.
|
||||
"""
|
||||
results_dict = dict()
|
||||
_, h, w = image_shapes
|
||||
seg_logit = torch.randn(batch_size, num_classes, h, w)
|
||||
results_dict['seg_logits'] = seg_logit
|
||||
seg_pred = np.random.randint(
|
||||
0, num_classes, (batch_size, h, w), dtype=np.uint8)
|
||||
seg_pred = torch.LongTensor(seg_pred)
|
||||
results_dict['pred_sem_seg'] = seg_pred
|
||||
|
||||
batch_datasampes = [
|
||||
SegDataSample()
|
||||
for _ in range(results_dict['pred_sem_seg'].shape[0])
|
||||
]
|
||||
for key, value in results_dict.items():
|
||||
for i in range(value.shape[0]):
|
||||
setattr(batch_datasampes[i], key, PixelData(data=value[i]))
|
||||
|
||||
_predictions = []
|
||||
for pred in batch_datasampes:
|
||||
if isinstance(pred, BaseDataElement):
|
||||
test_data = pred.to_dict()
|
||||
test_data['img_path'] = \
|
||||
'tests/data/pseudo_cityscapes_dataset/leftImg8bit/val/\
|
||||
frankfurt/frankfurt_000000_000294_leftImg8bit.png'
|
||||
|
||||
_predictions.append(test_data)
|
||||
else:
|
||||
_predictions.append(pred)
|
||||
return _predictions
|
||||
|
||||
def test_evaluate(self):
|
||||
"""Test using the metric in the same way as Evalutor."""
|
||||
|
||||
data_batch = self._demo_mm_inputs()
|
||||
predictions = self._demo_mm_model_output()
|
||||
iou_metric = CitysMetric(citys_metrics=['cityscapes'])
|
||||
iou_metric.process(data_batch, predictions)
|
||||
res = iou_metric.evaluate(6)
|
||||
self.assertIsInstance(res, dict)
|
||||
# test to_label_id = True
|
||||
iou_metric = CitysMetric(
|
||||
citys_metrics=['cityscapes'], to_label_id=True)
|
||||
iou_metric.process(data_batch, predictions)
|
||||
res = iou_metric.evaluate(6)
|
||||
self.assertIsInstance(res, dict)
|
||||
import shutil
|
||||
shutil.rmtree('.format_cityscapes')
|