mmsegmentation/mmseg/evaluation/metrics/iou_metric.py

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from collections import OrderedDict
from typing import Dict, List, Optional, Sequence

import numpy as np
import torch
from mmengine.dist import is_main_process
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger, print_log
from mmengine.utils import mkdir_or_exist
from PIL import Image
from prettytable import PrettyTable

from mmseg.registry import METRICS


@METRICS.register_module()
class IoUMetric(BaseMetric):
"""IoU evaluation metric.
Args:
ignore_index (int): Index that will be ignored in evaluation.
Default: 255.
iou_metrics (list[str] | str): Metrics to be calculated, the options
includes 'mIoU', 'mDice' and 'mFscore'.
2022-06-02 22:15:28 +08:00
nan_to_num (int, optional): If specified, NaN values will be replaced
by the numbers defined by the user. Default: None.
beta (int): Determines the weight of recall in the combined score.
Default: 1.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
output_dir (str): The directory for output prediction. Defaults to
None.
format_only (bool): Only format result for results commit without
perform evaluation. It is useful when you want to save the result
to a specific format and submit it to the test server.
Defaults to False.
2022-06-02 22:15:28 +08:00
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Defaults to None.
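
    Examples:
        This metric is normally built from a config via the ``METRICS``
        registry rather than instantiated by hand. A minimal, illustrative
        evaluator setting (the surrounding dataset and model config is
        omitted and assumed to exist) might look like:

        >>> val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
        >>> test_evaluator = dict(
        ...     type='IoUMetric', iou_metrics=['mIoU', 'mFscore'])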
"""
def __init__(self,
ignore_index: int = 255,
iou_metrics: List[str] = ['mIoU'],
nan_to_num: Optional[int] = None,
beta: int = 1,
collect_device: str = 'cpu',
output_dir: Optional[str] = None,
format_only: bool = False,
prefix: Optional[str] = None,
**kwargs) -> None:
super().__init__(collect_device=collect_device, prefix=prefix)
self.ignore_index = ignore_index
self.metrics = iou_metrics
self.nan_to_num = nan_to_num
self.beta = beta
self.output_dir = output_dir
if self.output_dir and is_main_process():
mkdir_or_exist(self.output_dir)
        self.format_only = format_only

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
"""Process one batch of data and data_samples.
2022-06-02 22:15:28 +08:00
The processed results should be stored in ``self.results``, which will
be used to compute the metrics when all batches have been processed.
2022-06-02 22:15:28 +08:00
Args:
data_batch (dict): A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
2022-06-02 22:15:28 +08:00
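
        Examples:
            A minimal, illustrative call with hand-made tensors (in practice
            the data samples are produced by the runner and the model):

            >>> import torch
            >>> metric = IoUMetric(iou_metrics=['mIoU'])
            >>> metric.dataset_meta = dict(classes=('background', 'object'))
            >>> data_sample = dict(
            ...     pred_sem_seg=dict(data=torch.tensor([[[0, 1], [1, 1]]])),
            ...     gt_sem_seg=dict(data=torch.tensor([[[0, 1], [0, 1]]])))
            >>> metric.process(data_batch={}, data_samples=[data_sample])
            >>> len(metric.results)
            1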
"""
num_classes = len(self.dataset_meta['classes'])
for data_sample in data_samples:
pred_label = data_sample['pred_sem_seg']['data'].squeeze()
# format_only always for test dataset without ground truth
if not self.format_only:
label = data_sample['gt_sem_seg']['data'].squeeze().to(
pred_label)
self.results.append(
self.intersect_and_union(pred_label, label, num_classes,
self.ignore_index))
# format_result
if self.output_dir is not None:
basename = osp.splitext(osp.basename(
data_sample['img_path']))[0]
png_filename = osp.abspath(
osp.join(self.output_dir, f'{basename}.png'))
output_mask = pred_label.cpu().numpy()
# The index range of official ADE20k dataset is from 0 to 150.
# But the index range of output is from 0 to 149.
# That is because we set reduce_zero_label=True.
if data_sample.get('reduce_zero_label', False):
output_mask = output_mask + 1
output = Image.fromarray(output_mask.astype(np.uint8))
                output.save(png_filename)

    def compute_metrics(self, results: list) -> Dict[str, float]:
"""Compute the metrics from processed results.
Args:
results (list): The processed results of each batch.
Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
                the metrics, and the values are the corresponding results.
                The keys mainly include aAcc, mIoU, mAcc, mDice, mFscore,
                mPrecision and mRecall.
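
        Examples:
            Continuing the illustrative example from ``process`` above (the
            metric values themselves depend on the predictions):

            >>> computed = metric.compute_metrics(metric.results)
            >>> sorted(computed.keys())
            ['aAcc', 'mAcc', 'mIoU']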
"""
logger: MMLogger = MMLogger.get_current_instance()
if self.format_only:
logger.info(f'results are saved to {osp.dirname(self.output_dir)}')
            return OrderedDict()

        # convert list of tuples to tuple of lists, e.g.
# [(A_1, B_1, C_1, D_1), ..., (A_n, B_n, C_n, D_n)] to
# ([A_1, ..., A_n], ..., [D_1, ..., D_n])
results = tuple(zip(*results))
assert len(results) == 4
total_area_intersect = sum(results[0])
total_area_union = sum(results[1])
total_area_pred_label = sum(results[2])
total_area_label = sum(results[3])
ret_metrics = self.total_area_to_metrics(
total_area_intersect, total_area_union, total_area_pred_label,
total_area_label, self.metrics, self.nan_to_num, self.beta)
class_names = self.dataset_meta['classes']
# summary table
ret_metrics_summary = OrderedDict({
ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
for ret_metric, ret_metric_value in ret_metrics.items()
})
metrics = dict()
for key, val in ret_metrics_summary.items():
if key == 'aAcc':
metrics[key] = val
else:
metrics['m' + key] = val
# each class table
ret_metrics.pop('aAcc', None)
ret_metrics_class = OrderedDict({
ret_metric: np.round(ret_metric_value * 100, 2)
for ret_metric, ret_metric_value in ret_metrics.items()
})
ret_metrics_class.update({'Class': class_names})
ret_metrics_class.move_to_end('Class', last=False)
class_table_data = PrettyTable()
for key, val in ret_metrics_class.items():
class_table_data.add_column(key, val)
print_log('per class results:', logger)
print_log('\n' + class_table_data.get_string(), logger=logger)
        return metrics

    @staticmethod
    def intersect_and_union(pred_label: torch.Tensor, label: torch.Tensor,
num_classes: int, ignore_index: int):
        """Calculate Intersection and Union.

        Args:
            pred_label (torch.Tensor): Prediction segmentation map with
                shape (H, W).
            label (torch.Tensor): Ground truth segmentation map with
                shape (H, W).
            num_classes (int): Number of categories.
            ignore_index (int): Index that will be ignored in evaluation.

        Returns:
torch.Tensor: The intersection of prediction and ground truth
histogram on all classes.
torch.Tensor: The union of prediction and ground truth histogram on
all classes.
torch.Tensor: The prediction histogram on all classes.
torch.Tensor: The ground truth histogram on all classes.
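
        Examples:
            A small worked example (the tensors are made up for
            illustration):

            >>> import torch
            >>> pred = torch.tensor([[0, 1], [1, 2]])
            >>> gt = torch.tensor([[0, 1], [2, 2]])
            >>> res = IoUMetric.intersect_and_union(
            ...     pred, gt, num_classes=3, ignore_index=255)
            >>> res[0]  # per-class intersection
            tensor([1., 1., 1.])
            >>> res[0] / res[1]  # per-class IoU
            tensor([1.0000, 0.5000, 0.5000])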
"""
mask = (label != ignore_index)
pred_label = pred_label[mask]
label = label[mask]
intersect = pred_label[pred_label == label]
area_intersect = torch.histc(
intersect.float(), bins=(num_classes), min=0,
max=num_classes - 1).cpu()
area_pred_label = torch.histc(
pred_label.float(), bins=(num_classes), min=0,
max=num_classes - 1).cpu()
area_label = torch.histc(
label.float(), bins=(num_classes), min=0,
max=num_classes - 1).cpu()
area_union = area_pred_label + area_label - area_intersect
        return area_intersect, area_union, area_pred_label, area_label

    @staticmethod
    def total_area_to_metrics(total_area_intersect: torch.Tensor,
                              total_area_union: torch.Tensor,
                              total_area_pred_label: torch.Tensor,
                              total_area_label: torch.Tensor,
metrics: List[str] = ['mIoU'],
nan_to_num: Optional[int] = None,
beta: int = 1):
"""Calculate evaluation metrics
Args:
2022-07-05 20:43:33 +08:00
total_area_intersect (np.ndarray): The intersection of prediction
and ground truth histogram on all classes.
total_area_union (np.ndarray): The union of prediction and ground
2022-06-02 22:15:28 +08:00
truth histogram on all classes.
2022-07-05 20:43:33 +08:00
total_area_pred_label (np.ndarray): The prediction histogram on
2022-06-02 22:15:28 +08:00
all classes.
2022-07-05 20:43:33 +08:00
total_area_label (np.ndarray): The ground truth histogram on
all classes.
metrics (List[str] | str): Metrics to be evaluated, 'mIoU' and
2022-06-02 22:15:28 +08:00
'mDice'.
nan_to_num (int, optional): If specified, NaN values will be
replaced by the numbers defined by the user. Default: None.
beta (int): Determines the weight of recall in the combined score.
Default: 1.
Returns:
2022-07-05 20:43:33 +08:00
Dict[str, np.ndarray]: per category evaluation metrics,
2022-06-02 22:15:28 +08:00
shape (num_classes, ).
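
        Examples:
            An illustrative call on accumulated areas (numbers chosen only
            for demonstration):

            >>> import torch
            >>> ret = IoUMetric.total_area_to_metrics(
            ...     total_area_intersect=torch.tensor([1., 1., 1.]),
            ...     total_area_union=torch.tensor([1., 2., 2.]),
            ...     total_area_pred_label=torch.tensor([1., 2., 1.]),
            ...     total_area_label=torch.tensor([1., 1., 2.]),
            ...     metrics=['mIoU'])
            >>> round(float(ret['aAcc']), 2), ret['IoU'].tolist()
            (0.75, [1.0, 0.5, 0.5])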
"""
def f_score(precision, recall, beta=1):
"""calculate the f-score value.
Args:
precision (float | torch.Tensor): The precision value.
recall (float | torch.Tensor): The recall value.
beta (int): Determines the weight of recall in the combined
score. Default: 1.
Returns:
[torch.tensor]: The f-score value.
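
            Example (illustrative; ``f_score`` is a local helper of this
            method):

            >>> f_score(0.5, 1.0, beta=1)
            0.6666666666666666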
"""
score = (1 + beta**2) * (precision * recall) / (
(beta**2 * precision) + recall)
            return score

        if isinstance(metrics, str):
metrics = [metrics]
allowed_metrics = ['mIoU', 'mDice', 'mFscore']
if not set(metrics).issubset(set(allowed_metrics)):
            raise KeyError(f'metrics {metrics} is not supported')

        all_acc = total_area_intersect.sum() / total_area_label.sum()
ret_metrics = OrderedDict({'aAcc': all_acc})
for metric in metrics:
if metric == 'mIoU':
iou = total_area_intersect / total_area_union
acc = total_area_intersect / total_area_label
ret_metrics['IoU'] = iou
ret_metrics['Acc'] = acc
elif metric == 'mDice':
dice = 2 * total_area_intersect / (
total_area_pred_label + total_area_label)
acc = total_area_intersect / total_area_label
ret_metrics['Dice'] = dice
ret_metrics['Acc'] = acc
elif metric == 'mFscore':
precision = total_area_intersect / total_area_pred_label
recall = total_area_intersect / total_area_label
f_value = torch.tensor([
f_score(x[0], x[1], beta) for x in zip(precision, recall)
])
ret_metrics['Fscore'] = f_value
ret_metrics['Precision'] = precision
                ret_metrics['Recall'] = recall

        ret_metrics = {
metric: value.numpy()
for metric, value in ret_metrics.items()
}
if nan_to_num is not None:
ret_metrics = OrderedDict({
metric: np.nan_to_num(metric_value, nan=nan_to_num)
for metric, metric_value in ret_metrics.items()
})
return ret_metrics