From 6053345b3d653d2c277b976f3399aed41352e455 Mon Sep 17 00:00:00 2001 From: "linfangjian.vendor" Date: Tue, 28 Jun 2022 03:21:33 +0000 Subject: [PATCH] [Refactor] Refactor cityscapes metrics --- mmseg/datasets/pipelines/formatting.py | 5 +- mmseg/metrics/__init__.py | 3 +- mmseg/metrics/citys_metric.py | 147 ++++++++++++++++++ ...kfurt_000000_000294_gtFine_instanceIds.png | Bin ...rankfurt_000000_000294_gtFine_labelIds.png | Bin ...urt_000000_000294_gtFine_labelTrainIds.png | Bin .../frankfurt_000000_000294_leftImg8bit.png | Bin tests/test_datasets/test_dataset.py | 4 +- tests/test_metrics/test_citys_metric.py | 112 +++++++++++++ 9 files changed, 266 insertions(+), 5 deletions(-) create mode 100644 mmseg/metrics/citys_metric.py rename tests/data/pseudo_cityscapes_dataset/gtFine/{ => val/frankfurt}/frankfurt_000000_000294_gtFine_instanceIds.png (100%) rename tests/data/pseudo_cityscapes_dataset/gtFine/{ => val/frankfurt}/frankfurt_000000_000294_gtFine_labelIds.png (100%) rename tests/data/pseudo_cityscapes_dataset/gtFine/{ => val/frankfurt}/frankfurt_000000_000294_gtFine_labelTrainIds.png (100%) rename tests/data/pseudo_cityscapes_dataset/leftImg8bit/{ => val/frankfurt}/frankfurt_000000_000294_leftImg8bit.png (100%) create mode 100644 tests/test_metrics/test_citys_metric.py diff --git a/mmseg/datasets/pipelines/formatting.py b/mmseg/datasets/pipelines/formatting.py index 9c390022a..7bb3075d6 100644 --- a/mmseg/datasets/pipelines/formatting.py +++ b/mmseg/datasets/pipelines/formatting.py @@ -41,8 +41,9 @@ class PackSegInputs(BaseTransform): """ def __init__(self, - meta_keys=('img_path', 'ori_shape', 'img_shape', 'pad_shape', - 'scale_factor', 'flip', 'flip_direction')): + meta_keys=('img_path', 'seg_map_path', 'ori_shape', + 'img_shape', 'pad_shape', 'scale_factor', 'flip', + 'flip_direction')): self.meta_keys = meta_keys def transform(self, results: dict) -> dict: diff --git a/mmseg/metrics/__init__.py b/mmseg/metrics/__init__.py index 73cb09127..aec08bb07 100644 --- a/mmseg/metrics/__init__.py +++ b/mmseg/metrics/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .citys_metric import CitysMetric from .iou_metric import IoUMetric -__all__ = ['IoUMetric'] +__all__ = ['IoUMetric', 'CitysMetric'] diff --git a/mmseg/metrics/citys_metric.py b/mmseg/metrics/citys_metric.py new file mode 100644 index 000000000..73516e778 --- /dev/null +++ b/mmseg/metrics/citys_metric.py @@ -0,0 +1,147 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Dict, List, Optional, Sequence + +import mmcv +import numpy as np +from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger, print_log +from PIL import Image + +from mmseg.registry import METRICS + + +@METRICS.register_module() +class CitysMetric(BaseMetric): + """Cityscapes evaluation metric. + + Args: + ignore_index (int): Index that will be ignored in evaluation. + Default: 255. + citys_metrics (list[str] | str): Metrics to be evaluated, + Default: ['cityscapes']. + to_label_id (bool): whether convert output to label_id for + submission. Default: True. + suffix (str): The filename prefix of the png files. + If the prefix is "somepath/xxx", the png files will be + named "somepath/xxx.png". Default: '.format_cityscapes'. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + """ + + def __init__(self, + ignore_index: int = 255, + citys_metrics: List[str] = ['cityscapes'], + to_label_id: bool = True, + suffix: str = '.format_cityscapes', + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + + self.ignore_index = ignore_index + self.metrics = citys_metrics + assert self.metrics[0] == 'cityscapes' + self.to_label_id = to_label_id + self.suffix = suffix + + def process(self, data_batch: Sequence[dict], + predictions: Sequence[dict]) -> None: + """Process one batch of data and predictions. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch (Sequence[dict]): A batch of data from the dataloader. + predictions (Sequence[dict]): A batch of outputs from the model. + """ + mmcv.mkdir_or_exist(self.suffix) + + for pred in predictions: + pred_label = pred['pred_sem_seg']['data'][0].cpu().numpy() + # results2img + if self.to_label_id: + pred_label = self._convert_to_label_id(pred_label) + basename = osp.splitext(osp.basename(pred['img_path']))[0] + png_filename = osp.join(self.suffix, f'{basename}.png') + output = Image.fromarray(pred_label.astype(np.uint8)).convert('P') + import cityscapesscripts.helpers.labels as CSLabels + palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8) + for label_id, label in CSLabels.id2label.items(): + palette[label_id] = label.color + output.putpalette(palette) + output.save(png_filename) + + ann_dir = osp.join( + data_batch[0]['data_sample']['seg_map_path'].split('val')[0], + 'val') + self.results.append(ann_dir) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): Testing results of the dataset. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + imgfile_prefix (str | None): The prefix of output image file + + Returns: + dict[str: float]: Cityscapes evaluation results. + """ + logger: MMLogger = MMLogger.get_current_instance() + try: + import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa + except ImportError: + raise ImportError('Please run "pip install cityscapesscripts" to ' + 'install cityscapesscripts first.') + msg = 'Evaluating in Cityscapes style' + + if logger is None: + msg = '\n' + msg + print_log(msg, logger=logger) + + result_dir = self.suffix + + eval_results = dict() + print_log(f'Evaluating results under {result_dir} ...', logger=logger) + + CSEval.args.evalInstLevelScore = True + CSEval.args.predictionPath = osp.abspath(result_dir) + CSEval.args.evalPixelAccuracy = True + CSEval.args.JSONOutput = False + + seg_map_list = [] + pred_list = [] + ann_dir = results[0] + # when evaluating with official cityscapesscripts, + # **_gtFine_labelIds.png is used + for seg_map in mmcv.scandir( + ann_dir, 'gtFine_labelIds.png', recursive=True): + seg_map_list.append(osp.join(ann_dir, seg_map)) + pred_list.append(CSEval.getPrediction(CSEval.args, seg_map)) + metric = dict() + eval_results.update( + CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args)) + metric['averageScoreCategories'] = eval_results[ + 'averageScoreCategories'] + metric['averageScoreInstCategories'] = eval_results[ + 'averageScoreInstCategories'] + return metric + + @staticmethod + def _convert_to_label_id(result): + """Convert trainId to id for cityscapes.""" + if isinstance(result, str): + result = np.load(result) + import cityscapesscripts.helpers.labels as CSLabels + result_copy = result.copy() + for trainId, label in CSLabels.trainId2label.items(): + result_copy[result == trainId] = label.id + + return result_copy diff --git a/tests/data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_instanceIds.png b/tests/data/pseudo_cityscapes_dataset/gtFine/val/frankfurt/frankfurt_000000_000294_gtFine_instanceIds.png similarity index 100% rename from tests/data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_instanceIds.png rename to tests/data/pseudo_cityscapes_dataset/gtFine/val/frankfurt/frankfurt_000000_000294_gtFine_instanceIds.png diff --git a/tests/data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_labelIds.png b/tests/data/pseudo_cityscapes_dataset/gtFine/val/frankfurt/frankfurt_000000_000294_gtFine_labelIds.png similarity index 100% rename from tests/data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_labelIds.png rename to tests/data/pseudo_cityscapes_dataset/gtFine/val/frankfurt/frankfurt_000000_000294_gtFine_labelIds.png diff --git a/tests/data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_labelTrainIds.png b/tests/data/pseudo_cityscapes_dataset/gtFine/val/frankfurt/frankfurt_000000_000294_gtFine_labelTrainIds.png similarity index 100% rename from tests/data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_labelTrainIds.png rename to tests/data/pseudo_cityscapes_dataset/gtFine/val/frankfurt/frankfurt_000000_000294_gtFine_labelTrainIds.png diff --git a/tests/data/pseudo_cityscapes_dataset/leftImg8bit/frankfurt_000000_000294_leftImg8bit.png b/tests/data/pseudo_cityscapes_dataset/leftImg8bit/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png similarity index 100% rename from tests/data/pseudo_cityscapes_dataset/leftImg8bit/frankfurt_000000_000294_leftImg8bit.png rename to tests/data/pseudo_cityscapes_dataset/leftImg8bit/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png diff --git a/tests/test_datasets/test_dataset.py b/tests/test_datasets/test_dataset.py index 307cc44cb..e42489385 100644 --- a/tests/test_datasets/test_dataset.py +++ b/tests/test_datasets/test_dataset.py @@ -173,10 +173,10 @@ def test_cityscapes(): data_prefix=dict( img_path=osp.join( osp.dirname(__file__), - '../data/pseudo_cityscapes_dataset/leftImg8bit'), + '../data/pseudo_cityscapes_dataset/leftImg8bit/val'), seg_map_path=osp.join( osp.dirname(__file__), - '../data/pseudo_cityscapes_dataset/gtFine'))) + '../data/pseudo_cityscapes_dataset/gtFine/val'))) assert len(test_dataset) == 1 diff --git a/tests/test_metrics/test_citys_metric.py b/tests/test_metrics/test_citys_metric.py new file mode 100644 index 000000000..7d5088618 --- /dev/null +++ b/tests/test_metrics/test_citys_metric.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import numpy as np +import torch +from mmengine.data import BaseDataElement, PixelData + +from mmseg.core import SegDataSample +from mmseg.metrics import CitysMetric + + +class TestCitysMetric(TestCase): + + def _demo_mm_inputs(self, + batch_size=1, + image_shapes=(3, 128, 256), + num_classes=5): + """Create a superset of inputs needed to run test or train batches. + + Args: + batch_size (int): batch size. Default to 2. + image_shapes (List[tuple], Optional): image shape. + Default to (3, 64, 64) + num_classes (int): number of different classes. + Default to 5. + """ + if isinstance(image_shapes, list): + assert len(image_shapes) == batch_size + else: + image_shapes = [image_shapes] * batch_size + + packed_inputs = [] + for idx in range(batch_size): + image_shape = image_shapes[idx] + _, h, w = image_shape + + mm_inputs = dict() + data_sample = SegDataSample() + gt_semantic_seg = np.random.randint( + 0, num_classes, (1, h, w), dtype=np.uint8) + gt_semantic_seg = torch.LongTensor(gt_semantic_seg) + gt_sem_seg_data = dict(data=gt_semantic_seg) + data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data) + mm_inputs['data_sample'] = data_sample.to_dict() + mm_inputs['data_sample']['seg_map_path'] = \ + 'tests/data/pseudo_cityscapes_dataset/gtFine/val/\ + frankfurt/frankfurt_000000_000294_gtFine_labelTrainIds.png' + + packed_inputs.append(mm_inputs) + + return packed_inputs + + def _demo_mm_model_output(self, + batch_size=1, + image_shapes=(3, 128, 256), + num_classes=5): + """Create a superset of inputs needed to run test or train batches. + + Args: + batch_size (int): batch size. Default to 2. + image_shapes (List[tuple], Optional): image shape. + Default to (3, 64, 64) + num_classes (int): number of different classes. + Default to 5. + """ + results_dict = dict() + _, h, w = image_shapes + seg_logit = torch.randn(batch_size, num_classes, h, w) + results_dict['seg_logits'] = seg_logit + seg_pred = np.random.randint( + 0, num_classes, (batch_size, h, w), dtype=np.uint8) + seg_pred = torch.LongTensor(seg_pred) + results_dict['pred_sem_seg'] = seg_pred + + batch_datasampes = [ + SegDataSample() + for _ in range(results_dict['pred_sem_seg'].shape[0]) + ] + for key, value in results_dict.items(): + for i in range(value.shape[0]): + setattr(batch_datasampes[i], key, PixelData(data=value[i])) + + _predictions = [] + for pred in batch_datasampes: + if isinstance(pred, BaseDataElement): + test_data = pred.to_dict() + test_data['img_path'] = \ + 'tests/data/pseudo_cityscapes_dataset/leftImg8bit/val/\ + frankfurt/frankfurt_000000_000294_leftImg8bit.png' + + _predictions.append(test_data) + else: + _predictions.append(pred) + return _predictions + + def test_evaluate(self): + """Test using the metric in the same way as Evalutor.""" + + data_batch = self._demo_mm_inputs() + predictions = self._demo_mm_model_output() + iou_metric = CitysMetric(citys_metrics=['cityscapes']) + iou_metric.process(data_batch, predictions) + res = iou_metric.evaluate(6) + self.assertIsInstance(res, dict) + # test to_label_id = True + iou_metric = CitysMetric( + citys_metrics=['cityscapes'], to_label_id=True) + iou_metric.process(data_batch, predictions) + res = iou_metric.evaluate(6) + self.assertIsInstance(res, dict) + import shutil + shutil.rmtree('.format_cityscapes')