Mirror of https://github.com/open-mmlab/mmsegmentation.git (synced 2025-06-03 22:03:48 +08:00)
Merge branch 'linfangjian/ioumetrics' into 'refactor_dev'

[Refactor] Refactor IoU metrics

See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!16

commit 4ce9c01e33
mmseg/metrics/__init__.py (new file, 4 lines)

@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .iou_metric import IoUMetric

__all__ = ['IoUMetric']
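In normal use the metric is selected from a config rather than imported directly. A minimal sketch of such a config entry, assuming the usual MMEngine-style evaluator wiring; the `val_evaluator` key is an assumption and not part of this diff, only `type` and the constructor arguments come from the class added below.

# Hypothetical config snippet: pick the new metric by its registered name.
val_evaluator = dict(type='IoUMetric', metrics=['mIoU'], ignore_index=255)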
mmseg/metrics/iou_metric.py (new file, 278 lines)

@@ -0,0 +1,278 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from typing import Dict, List, Optional, Sequence, Union

import mmcv
import numpy as np
import torch
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger, print_log
from prettytable import PrettyTable

from mmseg.registry import METRICS


@METRICS.register_module()
class IoUMetric(BaseMetric):
    """IoU evaluation metric.

    Args:
        ignore_index (int): Index that will be ignored in evaluation.
            Default: 255.
        metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        beta (int): Determines the weight of recall in the combined score.
            Default: 1.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """

    def __init__(self,
                 ignore_index: int = 255,
                 metrics: List[str] = ['mIoU'],
                 nan_to_num: Optional[int] = None,
                 beta: int = 1,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

        self.ignore_index = ignore_index
        self.metrics = metrics
        self.nan_to_num = nan_to_num
        self.beta = beta

    def process(self, data_batch: Sequence[dict],
                predictions: Sequence[dict]) -> None:
        """Process one batch of data and predictions.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence[dict]): A batch of data from the dataloader.
            predictions (Sequence[dict]): A batch of outputs from the model.
        """
        num_classes = len(self.dataset_meta['classes'])
        label_map = self.dataset_meta['label_map']
        reduce_zero_label = self.dataset_meta['reduce_zero_label']
        for data, pred in zip(data_batch, predictions):
            label = data['data_sample']['gt_sem_seg']['data'][0].cpu().numpy()
            pred_label = pred['pred_sem_seg']['data'][0]
            self.results.append(
                self.intersect_and_union(pred_label, label, num_classes,
                                         self.ignore_index, label_map,
                                         reduce_zero_label))

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
                the metrics, and the values are corresponding results. The key
                mainly includes aAcc, mIoU, mAcc, mDice, mFscore, mPrecision,
                mRecall.
        """
        logger: MMLogger = MMLogger.get_current_instance()

        # convert list of tuples to tuple of lists, e.g.
        # [(A_1, B_1, C_1, D_1), ..., (A_n, B_n, C_n, D_n)] to
        # ([A_1, ..., A_n], ..., [D_1, ..., D_n])
        results = tuple(zip(*results))
        assert len(results) == 4

        total_area_intersect = sum(results[0])
        total_area_union = sum(results[1])
        total_area_pred_label = sum(results[2])
        total_area_label = sum(results[3])
        ret_metrics = self.total_area_to_metrics(
            total_area_intersect, total_area_union, total_area_pred_label,
            total_area_label, self.metrics, self.nan_to_num, self.beta)

        class_names = self.dataset_meta['classes']

        # summary table
        ret_metrics_summary = OrderedDict({
            ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })
        metrics = dict()
        for key, val in ret_metrics_summary.items():
            if key == 'aAcc':
                metrics[key] = val
            else:
                metrics['m' + key] = val

        # each class table
        ret_metrics.pop('aAcc', None)
        ret_metrics_class = OrderedDict({
            ret_metric: np.round(ret_metric_value * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })
        ret_metrics_class.update({'Class': class_names})
        ret_metrics_class.move_to_end('Class', last=False)
        class_table_data = PrettyTable()
        for key, val in ret_metrics_class.items():
            class_table_data.add_column(key, val)

        print_log('per class results:', logger)
        print_log('\n' + class_table_data.get_string(), logger=logger)

        return metrics

    @staticmethod
    def intersect_and_union(pred_label: Union[np.ndarray, str],
                            label: Union[np.ndarray, str],
                            num_classes: int,
                            ignore_index: int,
                            label_map: dict = dict(),
                            reduce_zero_label: bool = False):
        """Calculate Intersection and Union.

        Args:
            pred_label (ndarray | str): Prediction segmentation map
                or predict result filename. The shape is (H, W).
            label (ndarray | str): Ground truth segmentation map
                or label filename. The shape is (H, W).
            num_classes (int): Number of categories.
            ignore_index (int): Index that will be ignored in evaluation.
            label_map (dict): Mapping old labels to new labels. The parameter
                will work only when label is str. Default: dict().
            reduce_zero_label (bool): Whether ignore zero label. The parameter
                will work only when label is str. Default: False.

        Returns:
            torch.Tensor: The intersection of prediction and ground truth
                histogram on all classes.
            torch.Tensor: The union of prediction and ground truth histogram
                on all classes.
            torch.Tensor: The prediction histogram on all classes.
            torch.Tensor: The ground truth histogram on all classes.
        """

        if isinstance(pred_label, str):
            pred_label = torch.from_numpy(np.load(pred_label))
        else:
            pred_label = torch.from_numpy(pred_label)

        if isinstance(label, str):
            label = torch.from_numpy(
                mmcv.imread(label, flag='unchanged', backend='pillow'))
        else:
            label = torch.from_numpy(label)

        if label_map is not None:
            for old_id, new_id in label_map.items():
                label[label == old_id] = new_id
        if reduce_zero_label:
            label[label == 0] = 255
            label = label - 1
            label[label == 254] = 255

        mask = (label != ignore_index)
        pred_label = pred_label[mask]
        label = label[mask]

        intersect = pred_label[pred_label == label]
        area_intersect = torch.histc(
            intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
        area_pred_label = torch.histc(
            pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
        area_label = torch.histc(
            label.float(), bins=(num_classes), min=0, max=num_classes - 1)
        area_union = area_pred_label + area_label - area_intersect
        return area_intersect, area_union, area_pred_label, area_label

    @staticmethod
    def total_area_to_metrics(total_area_intersect: np.ndarray,
                              total_area_union: np.ndarray,
                              total_area_pred_label: np.ndarray,
                              total_area_label: np.ndarray,
                              metrics: List[str] = ['mIoU'],
                              nan_to_num: Optional[int] = None,
                              beta: int = 1):
        """Calculate evaluation metrics.

        Args:
            total_area_intersect (ndarray): The intersection of prediction and
                ground truth histogram on all classes.
            total_area_union (ndarray): The union of prediction and ground
                truth histogram on all classes.
            total_area_pred_label (ndarray): The prediction histogram on all
                classes.
            total_area_label (ndarray): The ground truth histogram on
                all classes.
            metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and
                'mDice'.
            nan_to_num (int, optional): If specified, NaN values will be
                replaced by the numbers defined by the user. Default: None.
            beta (int): Determines the weight of recall in the combined score.
                Default: 1.

        Returns:
            Dict[str, ndarray]: per category evaluation metrics,
                shape (num_classes, ).
        """

        def f_score(precision, recall, beta=1):
            """Calculate the f-score value.

            Args:
                precision (float | torch.Tensor): The precision value.
                recall (float | torch.Tensor): The recall value.
                beta (int): Determines the weight of recall in the combined
                    score. Default: 1.

            Returns:
                [torch.Tensor]: The f-score value.
            """
            score = (1 + beta**2) * (precision * recall) / (
                (beta**2 * precision) + recall)
            return score

        if isinstance(metrics, str):
            metrics = [metrics]
        allowed_metrics = ['mIoU', 'mDice', 'mFscore']
        if not set(metrics).issubset(set(allowed_metrics)):
            raise KeyError('metrics {} is not supported'.format(metrics))

        all_acc = total_area_intersect.sum() / total_area_label.sum()
        ret_metrics = OrderedDict({'aAcc': all_acc})
        for metric in metrics:
            if metric == 'mIoU':
                iou = total_area_intersect / total_area_union
                acc = total_area_intersect / total_area_label
                ret_metrics['IoU'] = iou
                ret_metrics['Acc'] = acc
            elif metric == 'mDice':
                dice = 2 * total_area_intersect / (
                    total_area_pred_label + total_area_label)
                acc = total_area_intersect / total_area_label
                ret_metrics['Dice'] = dice
                ret_metrics['Acc'] = acc
            elif metric == 'mFscore':
                precision = total_area_intersect / total_area_pred_label
                recall = total_area_intersect / total_area_label
                f_value = torch.tensor([
                    f_score(x[0], x[1], beta) for x in zip(precision, recall)
                ])
                ret_metrics['Fscore'] = f_value
                ret_metrics['Precision'] = precision
                ret_metrics['Recall'] = recall

        ret_metrics = {
            metric: value.numpy()
            for metric, value in ret_metrics.items()
        }
        if nan_to_num is not None:
            ret_metrics = OrderedDict({
                metric: np.nan_to_num(metric_value, nan=nan_to_num)
                for metric, metric_value in ret_metrics.items()
            })
        return ret_metrics
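For reference, a minimal sketch (not part of the diff) of how the two static helpers compose on a toy 2x2 prediction. It assumes the package from this branch is importable; the expected IoU values in the comments follow directly from the histogram bookkeeping above.

# Sketch: per-class areas for a 2x2 prediction, then the derived metrics.
import numpy as np

from mmseg.metrics import IoUMetric

pred = np.array([[0, 1], [1, 1]], dtype=np.uint8)  # predicted class ids
gt = np.array([[0, 1], [0, 1]], dtype=np.uint8)    # ground-truth class ids

# Returns (intersect, union, pred_hist, label_hist), each a length-2 tensor.
areas = IoUMetric.intersect_and_union(pred, gt, num_classes=2, ignore_index=255)

# Reduce the histograms to metrics: class 0 IoU = 1/2, class 1 IoU = 2/3,
# overall pixel accuracy (aAcc) = 3/4.
metrics = IoUMetric.total_area_to_metrics(*areas, metrics=['mIoU'])
print(metrics['IoU'], metrics['aAcc'])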
@@ -72,6 +72,7 @@ def register_all_modules(init_default_scope: bool = True) -> None:
     import mmseg.core  # noqa: F401,F403
     import mmseg.datasets  # noqa: F401,F403
     import mmseg.datasets.pipelines  # noqa: F401,F403
+    import mmseg.metrics  # noqa: F401,F403
     import mmseg.models  # noqa: F401,F403

     if init_default_scope:
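The extra import makes the new mmseg.metrics package register its classes as a side effect of being imported. A small sketch of what that enables, assuming MMEngine's standard Registry.build behaviour; everything besides the registered name and constructor arguments is illustrative.

# Hypothetical usage sketch: importing mmseg.metrics executes the
# @METRICS.register_module() decorator, after which the metric can be
# built from a plain config dict through the registry.
import mmseg.metrics  # noqa: F401  (triggers registration, as in the hunk above)
from mmseg.registry import METRICS

metric = METRICS.build(dict(type='IoUMetric', metrics=['mIoU', 'mDice']))
print(type(metric).__name__)  # IoUMetric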
tests/test_metrics/test_iou_metric.py (new file, 99 lines)

@@ -0,0 +1,99 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
import torch
from mmengine.data import BaseDataElement, PixelData

from mmseg.core import SegDataSample
from mmseg.metrics import IoUMetric


class TestIoUMetric(TestCase):

    def _demo_mm_inputs(self,
                        batch_size=2,
                        image_shapes=(3, 64, 64),
                        num_classes=5):
        """Create a superset of inputs needed to run test or train batches.

        Args:
            batch_size (int): batch size. Defaults to 2.
            image_shapes (List[tuple], Optional): image shape.
                Defaults to (3, 64, 64).
            num_classes (int): number of different classes.
                Defaults to 5.
        """
        if isinstance(image_shapes, list):
            assert len(image_shapes) == batch_size
        else:
            image_shapes = [image_shapes] * batch_size

        packed_inputs = []
        for idx in range(batch_size):
            image_shape = image_shapes[idx]
            _, h, w = image_shape

            mm_inputs = dict()
            data_sample = SegDataSample()
            gt_semantic_seg = np.random.randint(
                0, num_classes, (1, h, w), dtype=np.uint8)
            gt_semantic_seg = torch.LongTensor(gt_semantic_seg)
            gt_sem_seg_data = dict(data=gt_semantic_seg)
            data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
            mm_inputs['data_sample'] = data_sample.to_dict()
            packed_inputs.append(mm_inputs)

        return packed_inputs

    def _demo_mm_model_output(self,
                              batch_size=2,
                              image_shapes=(3, 64, 64),
                              num_classes=5):
        """Create a superset of model outputs needed to run test batches.

        Args:
            batch_size (int): batch size. Defaults to 2.
            image_shapes (List[tuple], Optional): image shape.
                Defaults to (3, 64, 64).
            num_classes (int): number of different classes.
                Defaults to 5.
        """
        results_dict = dict()
        _, h, w = image_shapes
        seg_logit = torch.randn(batch_size, num_classes, h, w)
        results_dict['seg_logits'] = seg_logit
        seg_pred = np.random.randint(
            0, num_classes, (batch_size, h, w), dtype=np.uint8)
        results_dict['pred_sem_seg'] = seg_pred

        batch_datasamples = [
            SegDataSample()
            for _ in range(results_dict['pred_sem_seg'].shape[0])
        ]
        for key, value in results_dict.items():
            for i in range(value.shape[0]):
                setattr(batch_datasamples[i], key, PixelData(data=value[i]))

        _predictions = []
        for pred in batch_datasamples:
            if isinstance(pred, BaseDataElement):
                _predictions.append(pred.to_dict())
            else:
                _predictions.append(pred)
        return _predictions

    def test_evaluate(self):
        """Test using the metric in the same way as Evaluator."""

        data_batch = self._demo_mm_inputs()
        predictions = self._demo_mm_model_output()

        iou_metric = IoUMetric(metrics=['mIoU'])
        iou_metric.dataset_meta = dict(
            classes=['wall', 'building', 'sky', 'floor', 'tree'],
            label_map=dict(),
            reduce_zero_label=False)
        iou_metric.process(data_batch, predictions)
        res = iou_metric.evaluate(6)
        self.assertIsInstance(res, dict)
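A small sketch of how the new test might be run locally with the standard library runner (pytest would work equally well); the discovery path matches the file added above, and it assumes execution from the repository root with mmseg and its test dependencies (mmengine, prettytable) installed.

# Sketch: discover and run the new test module with unittest.
import unittest

suite = unittest.defaultTestLoader.discover(
    'tests/test_metrics', pattern='test_iou_metric.py')
unittest.TextTestRunner(verbosity=2).run(suite)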