EasyCV/easycv/core/evaluation/segmentation_eval.py

# Copyright (c) Alibaba, Inc. and its affiliates.
from collections import OrderedDict

import numpy as np
import torch
from prettytable import PrettyTable

from easycv.utils.logger import print_log

from .base_evaluator import Evaluator
from .builder import EVALUATORS
from .metric_registry import METRICS
from .metrics import f_score

_ALLOWED_METRICS = ['mIoU', 'mDice', 'mFscore']

@EVALUATORS.register_module
class SegmentationEvaluator(Evaluator):

    def __init__(self, classes, dataset_name=None, metric_names=['mIoU']):
        """
        Args:
            classes (tuple | list): class name list
            dataset_name (str): dataset name
            metric_names (List[str]): metric names this evaluator will return
        """
        super().__init__(dataset_name, metric_names)
        self.classes = classes

        if isinstance(self._metric_names, str):
            self._metric_names = [self._metric_names]
        if not set(self._metric_names).issubset(set(_ALLOWED_METRICS)):
            raise KeyError('metric {} is not supported'.format(
                self._metric_names))

    def _evaluate_impl(self, prediction_dict, groundtruth_dict):
        """
        Args:
            prediction_dict: dict of prediction results, each value a list of
                tensors or numpy arrays. Must contain:
                seg_pred: list with one entry per test image, each an integer
                    numpy array of shape [width * height].
            groundtruth_dict: dict of groundtruth info, each value a list of
                tensors or numpy arrays. Must contain:
                gt_seg_maps: list with one entry per test image, each an
                    integer numpy array of shape [width * height].

        Returns:
            dict: each key is a metric name, each value is the metric value
        """
        results = prediction_dict['seg_pred']
        gt_seg_maps = groundtruth_dict['gt_seg_maps']

        ret_metrics = eval_metrics(
            results,
            gt_seg_maps,
            len(self.classes),
            self._metric_names,
        )

        return self._format_results(ret_metrics)

    def _format_results(self, ret_metrics):
        eval_results = {}

        # summary table
        ret_metrics_summary = OrderedDict({
            ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })

        # each class table
        ret_metrics.pop('aAcc', None)
        ret_metrics_class = OrderedDict({
            ret_metric: np.round(ret_metric_value * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })
        ret_metrics_class.update({'Class': self.classes})
        ret_metrics_class.move_to_end('Class', last=False)

        # for logger
        class_table_data = PrettyTable()
        for key, val in ret_metrics_class.items():
            class_table_data.add_column(key, val)

        summary_table_data = PrettyTable()
        for key, val in ret_metrics_summary.items():
            if key == 'aAcc':
                summary_table_data.add_column(key, [val])
            else:
                summary_table_data.add_column('m' + key, [val])

        print_log('per class results:')
        print_log('\n' + class_table_data.get_string())
        print_log('Summary:')
        print_log('\n' + summary_table_data.get_string())

        # each metric dict
        for key, value in ret_metrics_summary.items():
            if key == 'aAcc':
                eval_results[key] = value / 100.0
            else:
                eval_results['m' + key] = value / 100.0

        ret_metrics_class.pop('Class', None)
        for key, value in ret_metrics_class.items():
            eval_results.update({
                key + '.' + str(name): value[idx] / 100.0
                for idx, name in enumerate(self.classes)
            })

        return eval_results


METRICS.register_default_best_metric(SegmentationEvaluator, 'mIoU', 'max')
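
# Illustrative usage sketch (not part of the original module). The public
# entry point is assumed to be provided by the ``Evaluator`` base class and
# to dispatch to ``_evaluate_impl``; calling ``_evaluate_impl`` directly, as
# below, exercises the same code path:
#
#     evaluator = SegmentationEvaluator(classes=('background', 'road'))
#     preds = [np.zeros((4, 4), dtype=np.int64) for _ in range(2)]
#     res = evaluator._evaluate_impl({'seg_pred': preds},
#                                    {'gt_seg_maps': preds})
#     # res holds 'mIoU', 'mAcc' and 'aAcc' plus per-class entries such as
#     # 'IoU.background' (per-class values are NaN for classes absent from
#     # both prediction and ground truth).
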
def intersect_and_union(
    pred_label,
    label,
    num_classes,
):
    """Calculate intersection and union.

    Args:
        pred_label (ndarray): Prediction segmentation map.
        label (ndarray): Ground truth segmentation map.
        num_classes (int): Number of categories.

    Returns:
        torch.Tensor: The intersection of prediction and ground truth
            histogram on all classes.
        torch.Tensor: The union of prediction and ground truth histogram on
            all classes.
        torch.Tensor: The prediction histogram on all classes.
        torch.Tensor: The ground truth histogram on all classes.
    """
    pred_label = torch.from_numpy(pred_label)
    label = torch.from_numpy(label)

    intersect = pred_label[pred_label == label]
    area_intersect = torch.histc(
        intersect.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_pred_label = torch.histc(
        pred_label.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_label = torch.histc(
        label.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_union = area_pred_label + area_label - area_intersect

    return area_intersect, area_union, area_pred_label, area_label
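
# Worked example (illustrative, not part of the original module): with
# num_classes = 2 and
#
#     pred = np.array([[0, 1], [1, 1]])
#     label = np.array([[0, 1], [0, 1]])
#
# prediction and ground truth agree on 3 of the 4 pixels, so
#     area_intersect  = [1., 2.]   # correctly predicted pixels per class
#     area_union      = [2., 3.]   # pred + label - intersect, per class
#     area_pred_label = [1., 3.]   # predicted pixels per class
#     area_label      = [2., 2.]   # ground truth pixels per class
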
def eval_metrics(results,
                 gt_seg_maps,
                 num_classes,
                 metrics=['mIoU'],
                 nan_to_num=None,
                 beta=1):
    """Calculate evaluation metrics over a set of images.

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
        num_classes (int): Number of categories.
        metrics (list[str] | str): Metrics to be evaluated, a subset of
            'mIoU', 'mDice' and 'mFscore'. Default: ['mIoU'].
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the number defined by the user. Default: None.
        beta (int): Determines the weight of recall in the combined F-score,
            only used for 'mFscore'. Default: 1.

    Returns:
        dict[str, ndarray]: Metric results keyed by name, e.g. 'aAcc'
            (overall accuracy) plus per-category arrays such as 'IoU' and
            'Acc' of shape (num_classes, ).
    """
    total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)

    for result, gt_seg_map in zip(results, gt_seg_maps):
        area_intersect, area_union, area_pred_label, area_label = \
            intersect_and_union(result, gt_seg_map, num_classes)
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label

    ret_metrics = total_area_to_metrics(total_area_intersect, total_area_union,
                                        total_area_pred_label,
                                        total_area_label, metrics, nan_to_num,
                                        beta)

    return ret_metrics
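
# Continuing the worked example above (illustrative): aggregating that single
# image pair with
#
#     eval_metrics([pred], [label], num_classes=2)
#
# returns {'aAcc': 0.75, 'IoU': [0.5, 0.6667], 'Acc': [0.5, 1.0]} (as numpy
# arrays), i.e. 3 of 4 pixels correct overall and a mean IoU of ~0.583.
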
def total_area_to_metrics(total_area_intersect,
                          total_area_union,
                          total_area_pred_label,
                          total_area_label,
                          metrics=['mIoU'],
                          nan_to_num=None,
                          beta=1):
    """Calculate evaluation metrics from accumulated per-class histograms.

    Args:
        total_area_intersect (Tensor): The intersection of prediction and
            ground truth histogram on all classes.
        total_area_union (Tensor): The union of prediction and ground truth
            histogram on all classes.
        total_area_pred_label (Tensor): The prediction histogram on all
            classes.
        total_area_label (Tensor): The ground truth histogram on all classes.
        metrics (list[str] | str): Metrics to be evaluated, a subset of
            'mIoU', 'mDice' and 'mFscore'. Default: ['mIoU'].
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the number defined by the user. Default: None.
        beta (int): Determines the weight of recall in the combined F-score,
            only used for 'mFscore'. Default: 1.

    Returns:
        dict[str, ndarray]: Metric results keyed by name, e.g. 'aAcc'
            (overall accuracy) plus per-category arrays such as 'IoU' and
            'Acc' of shape (num_classes, ).
    """
    if isinstance(metrics, str):
        metrics = [metrics]
    if not set(metrics).issubset(set(_ALLOWED_METRICS)):
        raise KeyError('metrics {} is not supported'.format(metrics))

    all_acc = total_area_intersect.sum() / total_area_label.sum()
    ret_metrics = OrderedDict({'aAcc': all_acc})
    for metric in metrics:
        if metric == 'mIoU':
            iou = total_area_intersect / total_area_union
            acc = total_area_intersect / total_area_label
            ret_metrics['IoU'] = iou
            ret_metrics['Acc'] = acc
        elif metric == 'mDice':
            dice = 2 * total_area_intersect / (
                total_area_pred_label + total_area_label)
            acc = total_area_intersect / total_area_label
            ret_metrics['Dice'] = dice
            ret_metrics['Acc'] = acc
        elif metric == 'mFscore':
            precision = total_area_intersect / total_area_pred_label
            recall = total_area_intersect / total_area_label
            f_value = torch.tensor(
                [f_score(x[0], x[1], beta) for x in zip(precision, recall)])
            ret_metrics['Fscore'] = f_value
            ret_metrics['Precision'] = precision
            ret_metrics['Recall'] = recall

    ret_metrics = {
        metric: value.numpy()
        for metric, value in ret_metrics.items()
    }
    if nan_to_num is not None:
        ret_metrics = OrderedDict({
            metric: np.nan_to_num(metric_value, nan=nan_to_num)
            for metric, metric_value in ret_metrics.items()
        })
    return ret_metrics
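

if __name__ == '__main__':
    # Minimal smoke test (illustrative sketch, not part of the original
    # module). It builds random ground truth maps, reuses them verbatim as
    # "perfect" predictions, and runs the evaluator end-to-end by calling
    # ``_evaluate_impl`` directly (normally the ``Evaluator`` base class is
    # assumed to invoke it).
    rng = np.random.RandomState(0)
    classes = ('background', 'road', 'car')
    gt_seg_maps = [rng.randint(0, len(classes), (8, 8)) for _ in range(4)]
    seg_preds = [gt.copy() for gt in gt_seg_maps]

    evaluator = SegmentationEvaluator(classes, metric_names=['mIoU'])
    results = evaluator._evaluate_impl({'seg_pred': seg_preds},
                                       {'gt_seg_maps': gt_seg_maps})
    print(results)  # perfect predictions: expect mIoU == mAcc == aAcc == 1.0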