# Copyright (c) Alibaba, Inc. and its affiliates.
from collections import OrderedDict

import numpy as np
import torch
from prettytable import PrettyTable

from easycv.utils.logger import print_log
from .base_evaluator import Evaluator
from .builder import EVALUATORS
from .metric_registry import METRICS
from .metrics import f_score

_ALLOWED_METRICS = ['mIoU', 'mDice', 'mFscore']


@EVALUATORS.register_module
class SegmentationEvaluator(Evaluator):

    def __init__(self, classes, dataset_name=None, metric_names=['mIoU']):
        """
        Args:
            classes (tuple | list): list of class names
            dataset_name (str): dataset name
            metric_names (List[str]): metric names this evaluator will return
        """
        super().__init__(dataset_name, metric_names)

        self.classes = classes
        if isinstance(self._metric_names, str):
            self._metric_names = [self._metric_names]
        if not set(self._metric_names).issubset(set(_ALLOWED_METRICS)):
            raise KeyError('metrics {} are not supported'.format(
                self._metric_names))

    def _evaluate_impl(self, prediction_dict, groundtruth_dict):
        """
        Args:
            prediction_dict: dict containing the segmentation results, with
                key ``seg_pred``: a list with one entry per test image, each
                an integer numpy array of shape [width * height].
            groundtruth_dict: dict containing the ground truth info, with
                key ``gt_seg_maps``: a list with one entry per test image,
                each an integer numpy array of shape [width * height].

        Returns:
            dict, where each key is a metric name and each value is the
            corresponding metric value
        """
        results = prediction_dict['seg_pred']
        gt_seg_maps = groundtruth_dict['gt_seg_maps']

        ret_metrics = eval_metrics(
            results,
            gt_seg_maps,
            len(self.classes),
            self._metric_names,
        )
        return self._format_results(ret_metrics)

    def _format_results(self, ret_metrics):
        """Log summary/per-class tables and flatten metrics into one dict."""
        eval_results = {}

        # summary table: mean over classes, in percent
        ret_metrics_summary = OrderedDict({
            ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })

        # per-class table, in percent
        ret_metrics.pop('aAcc', None)
        ret_metrics_class = OrderedDict({
            ret_metric: np.round(ret_metric_value * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })
        ret_metrics_class.update({'Class': self.classes})
        ret_metrics_class.move_to_end('Class', last=False)

        # for logger
        class_table_data = PrettyTable()
        for key, val in ret_metrics_class.items():
            class_table_data.add_column(key, val)

        summary_table_data = PrettyTable()
        for key, val in ret_metrics_summary.items():
            if key == 'aAcc':
                summary_table_data.add_column(key, [val])
            else:
                summary_table_data.add_column('m' + key, [val])

        print_log('per class results:')
        print_log('\n' + class_table_data.get_string())
        print_log('Summary:')
        print_log('\n' + summary_table_data.get_string())

        # flatten summary metrics back to the [0, 1] range
        for key, value in ret_metrics_summary.items():
            if key == 'aAcc':
                eval_results[key] = value / 100.0
            else:
                eval_results['m' + key] = value / 100.0

        # add per-class metrics keyed as '<metric>.<class name>'
        ret_metrics_class.pop('Class', None)
        for key, value in ret_metrics_class.items():
            eval_results.update({
                key + '.' + str(name): value[idx] / 100.0
                for idx, name in enumerate(self.classes)
            })

        return eval_results


METRICS.register_default_best_metric(SegmentationEvaluator, 'mIoU', 'max')


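# Example (illustrative sketch, not part of the original module): how the
# prediction/ground-truth dicts described in `_evaluate_impl` might be
# assembled for a toy 2-class task. `_evaluate_impl` is called directly here;
# the public entry point provided by the `Evaluator` base class may wrap it
# differently.
def _example_segmentation_evaluator():
    evaluator = SegmentationEvaluator(classes=['background', 'foreground'])
    prediction_dict = {'seg_pred': [np.array([0, 1, 1, 0])]}  # flattened map
    groundtruth_dict = {'gt_seg_maps': [np.array([0, 1, 0, 0])]}
    # Returns a flat dict such as {'aAcc': ..., 'mIoU': ..., 'mAcc': ...,
    # 'IoU.background': ..., 'IoU.foreground': ..., ...}, values in [0, 1].
    return evaluator._evaluate_impl(prediction_dict, groundtruth_dict)

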
def intersect_and_union(pred_label, label, num_classes):
    """Calculate intersection and Union.

    Args:
        pred_label (ndarray): Prediction segmentation map.
        label (ndarray): Ground truth segmentation map.
        num_classes (int): Number of categories.

    Returns:
        torch.Tensor: The intersection of prediction and ground truth
            histogram on all classes.
        torch.Tensor: The union of prediction and ground truth histogram on
            all classes.
        torch.Tensor: The prediction histogram on all classes.
        torch.Tensor: The ground truth histogram on all classes.
    """
    pred_label = torch.from_numpy(pred_label)
    label = torch.from_numpy(label)

    intersect = pred_label[pred_label == label]
    area_intersect = torch.histc(
        intersect.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_pred_label = torch.histc(
        pred_label.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_label = torch.histc(
        label.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_union = area_pred_label + area_label - area_intersect
    return area_intersect, area_union, area_pred_label, area_label


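# Example (illustrative sketch, not part of the original module): per-class
# area histograms for a 3-class toy case; per-class IoU then follows as
# area_intersect / area_union.
def _example_intersect_and_union():
    pred_label = np.array([0, 1, 1, 2])
    label = np.array([0, 1, 2, 2])
    area_intersect, area_union, area_pred_label, area_label = \
        intersect_and_union(pred_label, label, num_classes=3)
    # area_intersect -> [1., 1., 1.], area_union -> [1., 2., 2.]
    return area_intersect / area_union  # per-class IoU: [1.0, 0.5, 0.5]

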
def eval_metrics(results,
                 gt_seg_maps,
                 num_classes,
                 metrics=['mIoU'],
                 nan_to_num=None,
                 beta=1):
    """Calculate evaluation metrics.

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
        num_classes (int): Number of categories.
        metrics (list[str] | str): Metrics to be evaluated, any of 'mIoU',
            'mDice' and 'mFscore'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the number defined by the user. Default: None.
        beta (int): Determines the weight of recall in the combined Fscore.
            Default: 1.

    Returns:
        dict[str, ndarray]: Overall accuracy under key 'aAcc', plus
            per-category values of shape (num_classes, ) for each requested
            metric.
    """
    total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
    for result, gt_seg_map in zip(results, gt_seg_maps):
        area_intersect, area_union, area_pred_label, area_label = \
            intersect_and_union(result, gt_seg_map, num_classes)
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label

    ret_metrics = total_area_to_metrics(total_area_intersect, total_area_union,
                                        total_area_pred_label,
                                        total_area_label, metrics, nan_to_num,
                                        beta)

    return ret_metrics


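# Example (illustrative sketch, not part of the original module): aggregating
# mIoU over two toy 2-class images.
def _example_eval_metrics():
    results = [np.array([0, 1, 1, 0]), np.array([1, 1, 0, 0])]
    gt_seg_maps = [np.array([0, 1, 0, 0]), np.array([1, 1, 0, 0])]
    # Returns a dict with key 'aAcc' plus per-class arrays of shape
    # (num_classes,) under 'IoU' and 'Acc'.
    return eval_metrics(results, gt_seg_maps, num_classes=2, metrics=['mIoU'])

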
def total_area_to_metrics(total_area_intersect,
                          total_area_union,
                          total_area_pred_label,
                          total_area_label,
                          metrics=['mIoU'],
                          nan_to_num=None,
                          beta=1):
    """Calculate evaluation metrics from accumulated per-class areas.

    Args:
        total_area_intersect (torch.Tensor): The intersection of prediction
            and ground truth histogram on all classes.
        total_area_union (torch.Tensor): The union of prediction and ground
            truth histogram on all classes.
        total_area_pred_label (torch.Tensor): The prediction histogram on all
            classes.
        total_area_label (torch.Tensor): The ground truth histogram on all
            classes.
        metrics (list[str] | str): Metrics to be evaluated, any of 'mIoU',
            'mDice' and 'mFscore'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the number defined by the user. Default: None.
        beta (int): Determines the weight of recall in the combined Fscore.
            Default: 1.

    Returns:
        dict[str, ndarray]: Overall accuracy under key 'aAcc', plus
            per-category values of shape (num_classes, ) for each requested
            metric.
    """
    if isinstance(metrics, str):
        metrics = [metrics]
    if not set(metrics).issubset(set(_ALLOWED_METRICS)):
        raise KeyError('metrics {} are not supported'.format(metrics))

    all_acc = total_area_intersect.sum() / total_area_label.sum()
    ret_metrics = OrderedDict({'aAcc': all_acc})
    for metric in metrics:
        if metric == 'mIoU':
            iou = total_area_intersect / total_area_union
            acc = total_area_intersect / total_area_label
            ret_metrics['IoU'] = iou
            ret_metrics['Acc'] = acc
        elif metric == 'mDice':
            dice = 2 * total_area_intersect / (
                total_area_pred_label + total_area_label)
            acc = total_area_intersect / total_area_label
            ret_metrics['Dice'] = dice
            ret_metrics['Acc'] = acc
        elif metric == 'mFscore':
            precision = total_area_intersect / total_area_pred_label
            recall = total_area_intersect / total_area_label
            f_value = torch.tensor(
                [f_score(x[0], x[1], beta) for x in zip(precision, recall)])
            ret_metrics['Fscore'] = f_value
            ret_metrics['Precision'] = precision
            ret_metrics['Recall'] = recall

    ret_metrics = {
        metric: value.numpy()
        for metric, value in ret_metrics.items()
    }
    if nan_to_num is not None:
        ret_metrics = OrderedDict({
            metric: np.nan_to_num(metric_value, nan=nan_to_num)
            for metric, metric_value in ret_metrics.items()
        })
    return ret_metrics


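# Example (illustrative sketch, not part of the original module): computing
# Dice and Fscore directly from precomputed per-class area histograms
# (float tensors, one bin per class).
def _example_total_area_to_metrics():
    total_area_intersect = torch.tensor([4., 2.])
    total_area_union = torch.tensor([6., 5.])
    total_area_pred_label = torch.tensor([5., 4.])
    total_area_label = torch.tensor([5., 3.])
    # Returns {'aAcc': 0.75, 'Dice': [0.8, ~0.571], 'Acc': [0.8, ~0.667],
    # 'Fscore': ..., 'Precision': ..., 'Recall': ...} as numpy values.
    return total_area_to_metrics(
        total_area_intersect, total_area_union, total_area_pred_label,
        total_area_label, metrics=['mDice', 'mFscore'])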