# Copyright (c) Alibaba, Inc. and its affiliates.
from collections import OrderedDict

import numpy as np
import torch
from prettytable import PrettyTable

from easycv.utils.logger import print_log
from .base_evaluator import Evaluator
from .builder import EVALUATORS
from .metric_registry import METRICS
from .metrics import f_score

_ALLOWED_METRICS = ['mIoU', 'mDice', 'mFscore']


@EVALUATORS.register_module
class SegmentationEvaluator(Evaluator):

    def __init__(self, classes, dataset_name=None, metric_names=['mIoU']):
        """
        Args:
            classes (tuple | list): List of class names.
            dataset_name (str): Dataset name.
            metric_names (List[str]): Metric names this evaluator will return.
        """
        super().__init__(dataset_name, metric_names)
        self.classes = classes

        if isinstance(self._metric_names, str):
            self._metric_names = [self._metric_names]
        if not set(self._metric_names).issubset(set(_ALLOWED_METRICS)):
            raise KeyError('metric {} is not supported'.format(
                self._metric_names))

    def _evaluate_impl(self, prediction_dict, groundtruth_dict):
        """
        Args:
            prediction_dict: A dict whose values are lists of tensors or
                numpy arrays holding the segmentation results. Must contain
                ``seg_pred``: a list with one entry per test image, each an
                integer numpy array of shape [width * height].
            groundtruth_dict: A dict whose values are lists of tensors or
                numpy arrays holding the ground truth. Must contain
                ``gt_seg_maps``: a list with one entry per test image, each
                an integer numpy array of shape [width * height].

        Returns:
            dict: Each key is a metric name, each value the metric value.
        """
        results = prediction_dict['seg_pred']
        gt_seg_maps = groundtruth_dict['gt_seg_maps']

        ret_metrics = eval_metrics(
            results,
            gt_seg_maps,
            len(self.classes),
            self._metric_names,
        )

        return self._format_results(ret_metrics)

    def _format_results(self, ret_metrics):
        eval_results = {}

        # summary table
        ret_metrics_summary = OrderedDict({
            ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })

        # each class table
        ret_metrics.pop('aAcc', None)
        ret_metrics_class = OrderedDict({
            ret_metric: np.round(ret_metric_value * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })
        ret_metrics_class.update({'Class': self.classes})
        ret_metrics_class.move_to_end('Class', last=False)

        # for logger
        class_table_data = PrettyTable()
        for key, val in ret_metrics_class.items():
            class_table_data.add_column(key, val)

        summary_table_data = PrettyTable()
        for key, val in ret_metrics_summary.items():
            if key == 'aAcc':
                summary_table_data.add_column(key, [val])
            else:
                summary_table_data.add_column('m' + key, [val])

        print_log('per class results:')
        print_log('\n' + class_table_data.get_string())
        print_log('Summary:')
        print_log('\n' + summary_table_data.get_string())

        # each metric dict
        for key, value in ret_metrics_summary.items():
            if key == 'aAcc':
                eval_results[key] = value / 100.0
            else:
                eval_results['m' + key] = value / 100.0

        ret_metrics_class.pop('Class', None)
        for key, value in ret_metrics_class.items():
            eval_results.update({
                key + '.' + str(name): value[idx] / 100.0
                for idx, name in enumerate(self.classes)
            })

        return eval_results


METRICS.register_default_best_metric(SegmentationEvaluator, 'mIoU', 'max')

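
# A minimal usage sketch, kept in comments so nothing executes at import time.
# It assumes the base ``Evaluator`` exposes an ``evaluate()`` entry point that
# forwards to ``_evaluate_impl``; the class names and shapes below are
# illustrative placeholders, not part of this module.
#
#   evaluator = SegmentationEvaluator(
#       classes=('background', 'foreground'), metric_names=['mIoU'])
#   predictions = {'seg_pred': [np.zeros((16, ), dtype=np.int64)]}
#   groundtruths = {'gt_seg_maps': [np.zeros((16, ), dtype=np.int64)]}
#   results = evaluator.evaluate(predictions, groundtruths)
#   # -> dict with keys such as 'aAcc', 'mIoU', 'mAcc' and per-class entries
#   #    like 'IoU.background' and 'Acc.background'.
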

def intersect_and_union(
    pred_label,
    label,
    num_classes,
):
    """Calculate intersection and union.

    Args:
        pred_label (ndarray): Prediction segmentation map.
        label (ndarray): Ground truth segmentation map.
        num_classes (int): Number of categories.

    Returns:
        torch.Tensor: The intersection of prediction and ground truth
            histograms on all classes.
        torch.Tensor: The union of prediction and ground truth histograms
            on all classes.
        torch.Tensor: The prediction histogram on all classes.
        torch.Tensor: The ground truth histogram on all classes.
    """
    pred_label = torch.from_numpy(pred_label)
    label = torch.from_numpy(label)

    intersect = pred_label[pred_label == label]
    area_intersect = torch.histc(
        intersect.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_pred_label = torch.histc(
        pred_label.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_label = torch.histc(
        label.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_union = area_pred_label + area_label - area_intersect
    return area_intersect, area_union, area_pred_label, area_label


def eval_metrics(results,
                 gt_seg_maps,
                 num_classes,
                 metrics=['mIoU'],
                 nan_to_num=None,
                 beta=1):
    """Calculate evaluation metrics.

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
        num_classes (int): Number of categories.
        metrics (list[str] | str): Metrics to be evaluated. Options are
            'mIoU', 'mDice' and 'mFscore'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the number defined by the user. Default: None.
        beta (int): Determines the weight of recall in the combined score
            when 'mFscore' is evaluated. Default: 1.

    Returns:
        dict[str, ndarray]: Mapping from metric name to value, containing
            'aAcc' (overall pixel accuracy) and per-category arrays of shape
            (num_classes, ) such as 'IoU', 'Acc', 'Dice', 'Fscore',
            'Precision' and 'Recall', depending on ``metrics``.
    """
    total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
    for result, gt_seg_map in zip(results, gt_seg_maps):
        area_intersect, area_union, area_pred_label, area_label = \
            intersect_and_union(result, gt_seg_map, num_classes)
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label
    ret_metrics = total_area_to_metrics(total_area_intersect,
                                        total_area_union,
                                        total_area_pred_label,
                                        total_area_label, metrics, nan_to_num,
                                        beta)

    return ret_metrics


def total_area_to_metrics(total_area_intersect,
                          total_area_union,
                          total_area_pred_label,
                          total_area_label,
                          metrics=['mIoU'],
                          nan_to_num=None,
                          beta=1):
    """Calculate evaluation metrics from accumulated areas.

    Args:
        total_area_intersect (torch.Tensor): The intersection of prediction
            and ground truth histograms on all classes.
        total_area_union (torch.Tensor): The union of prediction and ground
            truth histograms on all classes.
        total_area_pred_label (torch.Tensor): The prediction histogram on all
            classes.
        total_area_label (torch.Tensor): The ground truth histogram on all
            classes.
        metrics (list[str] | str): Metrics to be evaluated. Options are
            'mIoU', 'mDice' and 'mFscore'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the number defined by the user. Default: None.
        beta (int): Determines the weight of recall in the combined score
            when 'mFscore' is evaluated. Default: 1.

    Returns:
        dict[str, ndarray]: Mapping from metric name to value, containing
            'aAcc' (overall pixel accuracy) and per-category arrays of shape
            (num_classes, ) such as 'IoU', 'Acc', 'Dice', 'Fscore',
            'Precision' and 'Recall', depending on ``metrics``.
""" if isinstance(metrics, str): metrics = [metrics] if not set(metrics).issubset(set(_ALLOWED_METRICS)): raise KeyError('metrics {} is not supported'.format(metrics)) all_acc = total_area_intersect.sum() / total_area_label.sum() ret_metrics = OrderedDict({'aAcc': all_acc}) for metric in metrics: if metric == 'mIoU': iou = total_area_intersect / total_area_union acc = total_area_intersect / total_area_label ret_metrics['IoU'] = iou ret_metrics['Acc'] = acc elif metric == 'mDice': dice = 2 * total_area_intersect / ( total_area_pred_label + total_area_label) acc = total_area_intersect / total_area_label ret_metrics['Dice'] = dice ret_metrics['Acc'] = acc elif metric == 'mFscore': precision = total_area_intersect / total_area_pred_label recall = total_area_intersect / total_area_label f_value = torch.tensor( [f_score(x[0], x[1], beta) for x in zip(precision, recall)]) ret_metrics['Fscore'] = f_value ret_metrics['Precision'] = precision ret_metrics['Recall'] = recall ret_metrics = { metric: value.numpy() for metric, value in ret_metrics.items() } if nan_to_num is not None: ret_metrics = OrderedDict({ metric: np.nan_to_num(metric_value, nan=nan_to_num) for metric, metric_value in ret_metrics.items() }) return ret_metrics