From 593229eb16591493bc6219c31ecda54e63b2498b Mon Sep 17 00:00:00 2001
From: "linfangjian.vendor"
Date: Thu, 2 Jun 2022 14:15:28 +0000
Subject: [PATCH] [Refactor] Refactor IoU metrics

---
 mmseg/metrics/__init__.py             |   4 +
 mmseg/metrics/iou_metric.py           | 278 ++++++++++++++++++++++++++
 mmseg/utils/set_env.py                |   1 +
 tests/test_metrics/test_iou_metric.py |  99 +++++++++
 4 files changed, 382 insertions(+)
 create mode 100644 mmseg/metrics/__init__.py
 create mode 100644 mmseg/metrics/iou_metric.py
 create mode 100644 tests/test_metrics/test_iou_metric.py

diff --git a/mmseg/metrics/__init__.py b/mmseg/metrics/__init__.py
new file mode 100644
index 000000000..73cb09127
--- /dev/null
+++ b/mmseg/metrics/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .iou_metric import IoUMetric
+
+__all__ = ['IoUMetric']
diff --git a/mmseg/metrics/iou_metric.py b/mmseg/metrics/iou_metric.py
new file mode 100644
index 000000000..c7964f33f
--- /dev/null
+++ b/mmseg/metrics/iou_metric.py
@@ -0,0 +1,278 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+from typing import Dict, List, Optional, Sequence, Union
+
+import mmcv
+import numpy as np
+import torch
+from mmengine.evaluator import BaseMetric
+from mmengine.logging import MMLogger, print_log
+from prettytable import PrettyTable
+
+from mmseg.registry import METRICS
+
+
+@METRICS.register_module()
+class IoUMetric(BaseMetric):
+    """IoU evaluation metric.
+
+    Args:
+        ignore_index (int): Index that will be ignored in evaluation.
+            Default: 255.
+        metrics (list[str] | str): Metrics to be evaluated, 'mIoU', 'mDice'
+            and 'mFscore'.
+        nan_to_num (int, optional): If specified, NaN values will be replaced
+            by the numbers defined by the user. Default: None.
+        beta (int): Determines the weight of recall in the combined score.
+            Default: 1.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+    """
+
+    def __init__(self,
+                 ignore_index: int = 255,
+                 metrics: List[str] = ['mIoU'],
+                 nan_to_num: Optional[int] = None,
+                 beta: int = 1,
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None) -> None:
+        super().__init__(collect_device=collect_device, prefix=prefix)
+
+        self.ignore_index = ignore_index
+        self.metrics = metrics
+        self.nan_to_num = nan_to_num
+        self.beta = beta
+
+    def process(self, data_batch: Sequence[dict],
+                predictions: Sequence[dict]) -> None:
+        """Process one batch of data and predictions.
+
+        The processed results should be stored in ``self.results``, which will
+        be used to compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch (Sequence[dict]): A batch of data from the dataloader.
+            predictions (Sequence[dict]): A batch of outputs from the model.
+        """
+        num_classes = len(self.dataset_meta['classes'])
+        label_map = self.dataset_meta['label_map']
+        reduce_zero_label = self.dataset_meta['reduce_zero_label']
+        for data, pred in zip(data_batch, predictions):
+            label = data['data_sample']['gt_sem_seg']['data'][0].cpu().numpy()
+            pred_label = pred['pred_sem_seg']['data'][0]
+            self.results.append(
+                self.intersect_and_union(pred_label, label, num_classes,
+                                         self.ignore_index, label_map,
+                                         reduce_zero_label))
+
+    def compute_metrics(self, results: list) -> Dict[str, float]:
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list): The processed results of each batch.
+
+        Returns:
+            Dict[str, float]: The computed metrics. The keys are the names of
+                the metrics, and the values are corresponding results. The
+                keys mainly include aAcc, mIoU, mAcc, mDice, mFscore,
+                mPrecision and mRecall.
+        """
+        logger: MMLogger = MMLogger.get_current_instance()
+
+        # convert list of tuples to tuple of lists, e.g.
+        # [(A_1, B_1, C_1, D_1), ..., (A_n, B_n, C_n, D_n)] to
+        # ([A_1, ..., A_n], ..., [D_1, ..., D_n])
+        results = tuple(zip(*results))
+        assert len(results) == 4
+
+        total_area_intersect = sum(results[0])
+        total_area_union = sum(results[1])
+        total_area_pred_label = sum(results[2])
+        total_area_label = sum(results[3])
+        ret_metrics = self.total_area_to_metrics(
+            total_area_intersect, total_area_union, total_area_pred_label,
+            total_area_label, self.metrics, self.nan_to_num, self.beta)
+
+        class_names = self.dataset_meta['classes']
+
+        # summary table
+        ret_metrics_summary = OrderedDict({
+            ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
+            for ret_metric, ret_metric_value in ret_metrics.items()
+        })
+        metrics = dict()
+        for key, val in ret_metrics_summary.items():
+            if key == 'aAcc':
+                metrics[key] = val
+            else:
+                metrics['m' + key] = val
+
+        # each class table
+        ret_metrics.pop('aAcc', None)
+        ret_metrics_class = OrderedDict({
+            ret_metric: np.round(ret_metric_value * 100, 2)
+            for ret_metric, ret_metric_value in ret_metrics.items()
+        })
+        ret_metrics_class.update({'Class': class_names})
+        ret_metrics_class.move_to_end('Class', last=False)
+        class_table_data = PrettyTable()
+        for key, val in ret_metrics_class.items():
+            class_table_data.add_column(key, val)
+
+        print_log('per class results:', logger)
+        print_log('\n' + class_table_data.get_string(), logger=logger)
+
+        return metrics
+
+    @staticmethod
+    def intersect_and_union(pred_label: Union[np.ndarray, str],
+                            label: Union[np.ndarray, str],
+                            num_classes: int,
+                            ignore_index: int,
+                            label_map: dict = dict(),
+                            reduce_zero_label: bool = False):
+        """Calculate Intersection and Union.
+
+        Args:
+            pred_label (ndarray | str): Prediction segmentation map
+                or predict result filename. The shape is (H, W).
+            label (ndarray | str): Ground truth segmentation map
+                or label filename. The shape is (H, W).
+            num_classes (int): Number of categories.
+            ignore_index (int): Index that will be ignored in evaluation.
+            label_map (dict): Mapping old labels to new labels.
+                Default: dict().
+            reduce_zero_label (bool): Whether to ignore the zero label.
+                Default: False.
+
+        Returns:
+            torch.Tensor: The intersection of prediction and ground truth
+                histogram on all classes.
+            torch.Tensor: The union of prediction and ground truth histogram
+                on all classes.
+            torch.Tensor: The prediction histogram on all classes.
+            torch.Tensor: The ground truth histogram on all classes.
+        """
+
+        if isinstance(pred_label, str):
+            pred_label = torch.from_numpy(np.load(pred_label))
+        else:
+            pred_label = torch.from_numpy(pred_label)
+
+        if isinstance(label, str):
+            label = torch.from_numpy(
+                mmcv.imread(label, flag='unchanged', backend='pillow'))
+        else:
+            label = torch.from_numpy(label)
+
+        if label_map is not None:
+            for old_id, new_id in label_map.items():
+                label[label == old_id] = new_id
+        if reduce_zero_label:
+            label[label == 0] = 255
+            label = label - 1
+            label[label == 254] = 255
+
+        mask = (label != ignore_index)
+        pred_label = pred_label[mask]
+        label = label[mask]
+
+        intersect = pred_label[pred_label == label]
+        area_intersect = torch.histc(
+            intersect.float(), bins=num_classes, min=0, max=num_classes - 1)
+        area_pred_label = torch.histc(
+            pred_label.float(), bins=num_classes, min=0, max=num_classes - 1)
+        area_label = torch.histc(
+            label.float(), bins=num_classes, min=0, max=num_classes - 1)
+        area_union = area_pred_label + area_label - area_intersect
+        return area_intersect, area_union, area_pred_label, area_label
+
+    @staticmethod
+    def total_area_to_metrics(total_area_intersect: np.ndarray,
+                              total_area_union: np.ndarray,
+                              total_area_pred_label: np.ndarray,
+                              total_area_label: np.ndarray,
+                              metrics: List[str] = ['mIoU'],
+                              nan_to_num: Optional[int] = None,
+                              beta: int = 1):
+        """Calculate evaluation metrics.
+        Args:
+            total_area_intersect (ndarray): The intersection of prediction and
+                ground truth histogram on all classes.
+            total_area_union (ndarray): The union of prediction and ground
+                truth histogram on all classes.
+            total_area_pred_label (ndarray): The prediction histogram on all
+                classes.
+            total_area_label (ndarray): The ground truth histogram on
+                all classes.
+            metrics (list[str] | str): Metrics to be evaluated, 'mIoU', 'mDice'
+                and 'mFscore'.
+            nan_to_num (int, optional): If specified, NaN values will be
+                replaced by the numbers defined by the user. Default: None.
+            beta (int): Determines the weight of recall in the combined score.
+                Default: 1.
+        Returns:
+            Dict[str, ndarray]: Per category evaluation metrics,
+                shape (num_classes, ).
+        """
+
+        def f_score(precision, recall, beta=1):
+            """Calculate the f-score value.
+
+            Args:
+                precision (float | torch.Tensor): The precision value.
+                recall (float | torch.Tensor): The recall value.
+                beta (int): Determines the weight of recall in the combined
+                    score. Default: 1.
+
+            Returns:
+                [torch.Tensor]: The f-score value.
+            """
+            score = (1 + beta**2) * (precision * recall) / (
+                (beta**2 * precision) + recall)
+            return score
+
+        if isinstance(metrics, str):
+            metrics = [metrics]
+        allowed_metrics = ['mIoU', 'mDice', 'mFscore']
+        if not set(metrics).issubset(set(allowed_metrics)):
+            raise KeyError('metrics {} is not supported'.format(metrics))
+
+        all_acc = total_area_intersect.sum() / total_area_label.sum()
+        ret_metrics = OrderedDict({'aAcc': all_acc})
+        for metric in metrics:
+            if metric == 'mIoU':
+                iou = total_area_intersect / total_area_union
+                acc = total_area_intersect / total_area_label
+                ret_metrics['IoU'] = iou
+                ret_metrics['Acc'] = acc
+            elif metric == 'mDice':
+                dice = 2 * total_area_intersect / (
+                    total_area_pred_label + total_area_label)
+                acc = total_area_intersect / total_area_label
+                ret_metrics['Dice'] = dice
+                ret_metrics['Acc'] = acc
+            elif metric == 'mFscore':
+                precision = total_area_intersect / total_area_pred_label
+                recall = total_area_intersect / total_area_label
+                f_value = torch.tensor([
+                    f_score(x[0], x[1], beta) for x in zip(precision, recall)
+                ])
+                ret_metrics['Fscore'] = f_value
+                ret_metrics['Precision'] = precision
+                ret_metrics['Recall'] = recall
+
+        ret_metrics = {
+            metric: value.numpy()
+            for metric, value in ret_metrics.items()
+        }
+        if nan_to_num is not None:
+            ret_metrics = OrderedDict({
+                metric: np.nan_to_num(metric_value, nan=nan_to_num)
+                for metric, metric_value in ret_metrics.items()
+            })
+        return ret_metrics
diff --git a/mmseg/utils/set_env.py b/mmseg/utils/set_env.py
index 2f6efbe8a..3db9c46d0 100644
--- a/mmseg/utils/set_env.py
+++ b/mmseg/utils/set_env.py
@@ -72,6 +72,7 @@ def register_all_modules(init_default_scope: bool = True) -> None:
     import mmseg.core  # noqa: F401,F403
     import mmseg.datasets  # noqa: F401,F403
     import mmseg.datasets.pipelines  # noqa: F401,F403
+    import mmseg.metrics  # noqa: F401,F403
     import mmseg.models  # noqa: F401,F403
 
     if init_default_scope:
diff --git a/tests/test_metrics/test_iou_metric.py b/tests/test_metrics/test_iou_metric.py
new file mode 100644
index 000000000..58d238bfd
--- /dev/null
+++ b/tests/test_metrics/test_iou_metric.py
@@ -0,0 +1,99 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import numpy as np
+import torch
+from mmengine.data import BaseDataElement, PixelData
+
+from mmseg.core import SegDataSample
+from mmseg.metrics import IoUMetric
+
+
+class TestIoUMetric(TestCase):
+
+    def _demo_mm_inputs(self,
+                        batch_size=2,
+                        image_shapes=(3, 64, 64),
+                        num_classes=5):
+        """Create a superset of inputs needed to run test or train batches.
+
+        Args:
+            batch_size (int): batch size. Defaults to 2.
+            image_shapes (List[tuple], optional): image shape.
+                Defaults to (3, 64, 64).
+            num_classes (int): number of different classes.
+                Defaults to 5.
+        """
+        if isinstance(image_shapes, list):
+            assert len(image_shapes) == batch_size
+        else:
+            image_shapes = [image_shapes] * batch_size
+
+        packed_inputs = []
+        for idx in range(batch_size):
+            image_shape = image_shapes[idx]
+            _, h, w = image_shape
+
+            mm_inputs = dict()
+            data_sample = SegDataSample()
+            gt_semantic_seg = np.random.randint(
+                0, num_classes, (1, h, w), dtype=np.uint8)
+            gt_semantic_seg = torch.LongTensor(gt_semantic_seg)
+            gt_sem_seg_data = dict(data=gt_semantic_seg)
+            data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
+            mm_inputs['data_sample'] = data_sample.to_dict()
+            packed_inputs.append(mm_inputs)
+
+        return packed_inputs
+
+    def _demo_mm_model_output(self,
+                              batch_size=2,
+                              image_shapes=(3, 64, 64),
+                              num_classes=5):
+        """Create a superset of outputs needed to run test or train batches.
+
+        Args:
+            batch_size (int): batch size. Defaults to 2.
+            image_shapes (List[tuple], optional): image shape.
+                Defaults to (3, 64, 64).
+            num_classes (int): number of different classes.
+                Defaults to 5.
+        """
+        results_dict = dict()
+        _, h, w = image_shapes
+        seg_logit = torch.randn(batch_size, num_classes, h, w)
+        results_dict['seg_logits'] = seg_logit
+        seg_pred = np.random.randint(
+            0, num_classes, (batch_size, h, w), dtype=np.uint8)
+        results_dict['pred_sem_seg'] = seg_pred
+
+        batch_datasamples = [
+            SegDataSample()
+            for _ in range(results_dict['pred_sem_seg'].shape[0])
+        ]
+        for key, value in results_dict.items():
+            for i in range(value.shape[0]):
+                setattr(batch_datasamples[i], key, PixelData(data=value[i]))
+
+        _predictions = []
+        for pred in batch_datasamples:
+            if isinstance(pred, BaseDataElement):
+                _predictions.append(pred.to_dict())
+            else:
+                _predictions.append(pred)
+        return _predictions
+
+    def test_evaluate(self):
+        """Test using the metric in the same way as Evaluator."""
+
+        data_batch = self._demo_mm_inputs()
+        predictions = self._demo_mm_model_output()
+
+        iou_metric = IoUMetric(metrics=['mIoU'])
+        iou_metric.dataset_meta = dict(
+            classes=['wall', 'building', 'sky', 'floor', 'tree'],
+            label_map=dict(),
+            reduce_zero_label=False)
+        iou_metric.process(data_batch, predictions)
+        res = iou_metric.evaluate(6)
+        self.assertIsInstance(res, dict)
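
Note (not part of the patch): for reviewers who want a quick sanity check of the
refactor, below is a minimal sketch of how the two static helpers added above
can be exercised on their own, outside a full Evaluator run. It uses only names
introduced in this diff; the class count, image size and ignore_index value are
arbitrary example choices.

    import numpy as np
    from mmseg.metrics import IoUMetric

    num_classes = 5
    # Random prediction / ground-truth maps standing in for real data.
    pred = np.random.randint(0, num_classes, (64, 64), dtype=np.uint8)
    gt = np.random.randint(0, num_classes, (64, 64), dtype=np.uint8)

    # Per-image area histograms: intersection, union, prediction, ground truth.
    areas = IoUMetric.intersect_and_union(
        pred, gt, num_classes=num_classes, ignore_index=255)

    # Aggregate the histograms (a single image here) into per-class IoU / Acc.
    ret = IoUMetric.total_area_to_metrics(*areas, metrics=['mIoU'])
    print(ret['IoU'], ret['Acc'])  # ndarrays of shape (num_classes,)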