diff --git a/mmseg/core/evaluation/metrics.py b/mmseg/core/evaluation/metrics.py index 95b096e7a..0f182b1c0 100644 --- a/mmseg/core/evaluation/metrics.py +++ b/mmseg/core/evaluation/metrics.py @@ -1,5 +1,6 @@ import mmcv import numpy as np +import torch def intersect_and_union(pred_label, @@ -11,8 +12,10 @@ def intersect_and_union(pred_label, """Calculate intersection and Union. Args: - pred_label (ndarray): Prediction segmentation map. - label (ndarray): Ground truth segmentation map. + pred_label (ndarray | str): Prediction segmentation map + or predict result filename. + label (ndarray | str): Ground truth segmentation map + or label filename. num_classes (int): Number of categories. ignore_index (int): Index that will be ignored in evaluation. label_map (dict): Mapping old labels to new labels. The parameter will @@ -21,25 +24,29 @@ def intersect_and_union(pred_label, work only when label is str. Default: False. Returns: - ndarray: The intersection of prediction and ground truth histogram - on all classes. - ndarray: The union of prediction and ground truth histogram on all - classes. - ndarray: The prediction histogram on all classes. - ndarray: The ground truth histogram on all classes. + torch.Tensor: The intersection of prediction and ground truth + histogram on all classes. + torch.Tensor: The union of prediction and ground truth histogram on + all classes. + torch.Tensor: The prediction histogram on all classes. + torch.Tensor: The ground truth histogram on all classes. """ if isinstance(pred_label, str): - pred_label = np.load(pred_label) + pred_label = torch.from_numpy(np.load(pred_label)) + else: + pred_label = torch.from_numpy((pred_label)) if isinstance(label, str): - label = mmcv.imread(label, flag='unchanged', backend='pillow') - # modify if custom classes + label = torch.from_numpy( + mmcv.imread(label, flag='unchanged', backend='pillow')) + else: + label = torch.from_numpy(label) + if label_map is not None: for old_id, new_id in label_map.items(): label[label == old_id] = new_id if reduce_zero_label: - # avoid using underflow conversion label[label == 0] = 255 label = label - 1 label[label == 254] = 255 @@ -49,13 +56,13 @@ def intersect_and_union(pred_label, label = label[mask] intersect = pred_label[pred_label == label] - area_intersect, _ = np.histogram( - intersect, bins=np.arange(num_classes + 1)) - area_pred_label, _ = np.histogram( - pred_label, bins=np.arange(num_classes + 1)) - area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1)) + area_intersect = torch.histc( + intersect.float(), bins=(num_classes), min=0, max=num_classes) + area_pred_label = torch.histc( + pred_label.float(), bins=(num_classes), min=0, max=num_classes) + area_label = torch.histc( + label.float(), bins=(num_classes), min=0, max=num_classes) area_union = area_pred_label + area_label - area_intersect - return area_intersect, area_union, area_pred_label, area_label @@ -68,8 +75,10 @@ def total_intersect_and_union(results, """Calculate Total Intersection and Union. Args: - results (list[ndarray]): List of prediction segmentation maps. - gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str]): list of ground truth + segmentation maps or list of label filenames. num_classes (int): Number of categories. ignore_index (int): Index that will be ignored in evaluation. label_map (dict): Mapping old labels to new labels. Default: dict(). @@ -83,23 +92,23 @@ def total_intersect_and_union(results, ndarray: The prediction histogram on all classes. ndarray: The ground truth histogram on all classes. """ - num_imgs = len(results) assert len(gt_seg_maps) == num_imgs - total_area_intersect = np.zeros((num_classes, ), dtype=np.float) - total_area_union = np.zeros((num_classes, ), dtype=np.float) - total_area_pred_label = np.zeros((num_classes, ), dtype=np.float) - total_area_label = np.zeros((num_classes, ), dtype=np.float) + total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64) + total_area_union = torch.zeros((num_classes, ), dtype=torch.float64) + total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64) + total_area_label = torch.zeros((num_classes, ), dtype=torch.float64) for i in range(num_imgs): area_intersect, area_union, area_pred_label, area_label = \ - intersect_and_union(results[i], gt_seg_maps[i], num_classes, - ignore_index, label_map, reduce_zero_label) + intersect_and_union( + results[i], gt_seg_maps[i], num_classes, ignore_index, + label_map, reduce_zero_label) total_area_intersect += area_intersect total_area_union += area_union total_area_pred_label += area_pred_label total_area_label += area_label - return total_area_intersect, total_area_union, \ - total_area_pred_label, total_area_label + return total_area_intersect, total_area_union, total_area_pred_label, \ + total_area_label def mean_iou(results, @@ -112,8 +121,10 @@ def mean_iou(results, """Calculate Mean Intersection and Union (mIoU) Args: - results (list[ndarray]): List of prediction segmentation maps. - gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str]): list of ground truth + segmentation maps or list of label filenames. num_classes (int): Number of categories. ignore_index (int): Index that will be ignored in evaluation. nan_to_num (int, optional): If specified, NaN values will be replaced @@ -126,7 +137,6 @@ def mean_iou(results, ndarray: Per category accuracy, shape (num_classes, ). ndarray: Per category IoU, shape (num_classes, ). """ - all_acc, acc, iou = eval_metrics( results=results, gt_seg_maps=gt_seg_maps, @@ -149,8 +159,10 @@ def mean_dice(results, """Calculate Mean Dice (mDice) Args: - results (list[ndarray]): List of prediction segmentation maps. - gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str]): list of ground truth + segmentation maps or list of label filenames. num_classes (int): Number of categories. ignore_index (int): Index that will be ignored in evaluation. nan_to_num (int, optional): If specified, NaN values will be replaced @@ -186,8 +198,10 @@ def eval_metrics(results, reduce_zero_label=False): """Calculate evaluation metrics Args: - results (list[ndarray]): List of prediction segmentation maps. - gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str]): list of ground truth + segmentation maps or list of label filenames. num_classes (int): Number of categories. ignore_index (int): Index that will be ignored in evaluation. metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'. @@ -200,17 +214,16 @@ def eval_metrics(results, ndarray: Per category accuracy, shape (num_classes, ). ndarray: Per category evalution metrics, shape (num_classes, ). """ - if isinstance(metrics, str): metrics = [metrics] allowed_metrics = ['mIoU', 'mDice'] if not set(metrics).issubset(set(allowed_metrics)): raise KeyError('metrics {} is not supported'.format(metrics)) + total_area_intersect, total_area_union, total_area_pred_label, \ - total_area_label = total_intersect_and_union(results, gt_seg_maps, - num_classes, ignore_index, - label_map, - reduce_zero_label) + total_area_label = total_intersect_and_union( + results, gt_seg_maps, num_classes, ignore_index, label_map, + reduce_zero_label) all_acc = total_area_intersect.sum() / total_area_label.sum() acc = total_area_intersect / total_area_label ret_metrics = [all_acc, acc] @@ -222,6 +235,7 @@ def eval_metrics(results, dice = 2 * total_area_intersect / ( total_area_pred_label + total_area_label) ret_metrics.append(dice) + ret_metrics = [metric.numpy() for metric in ret_metrics] if nan_to_num is not None: ret_metrics = [ np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 023bbb0a5..2033617c2 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -164,3 +164,45 @@ def test_mean_dice(): results, label, num_classes, ignore_index=255, nan_to_num=-1) assert acc[-1] == -1 assert iou[-1] == -1 + + +def test_filename_inputs(): + import cv2 + import tempfile + + def save_arr(input_arrays: list, title: str, is_image: bool, dir: str): + filenames = [] + SUFFIX = '.png' if is_image else '.npy' + for idx, arr in enumerate(input_arrays): + filename = '{}/{}-{}{}'.format(dir, title, idx, SUFFIX) + if is_image: + cv2.imwrite(filename, arr) + else: + np.save(filename, arr) + filenames.append(filename) + return filenames + + pred_size = (10, 512, 1024) + num_classes = 19 + ignore_index = 255 + results = np.random.randint(0, num_classes, size=pred_size) + labels = np.random.randint(0, num_classes, size=pred_size) + labels[:, 2, 5:10] = ignore_index + + with tempfile.TemporaryDirectory() as temp_dir: + + result_files = save_arr(results, 'pred', False, temp_dir) + label_files = save_arr(labels, 'label', True, temp_dir) + + all_acc, acc, iou = eval_metrics( + result_files, + label_files, + num_classes, + ignore_index, + metrics='mIoU') + + all_acc_l, acc_l, iou_l = legacy_mean_iou(results, labels, num_classes, + ignore_index) + assert all_acc == all_acc_l + assert np.allclose(acc, acc_l) + assert np.allclose(iou, iou_l)