2021-01-10 15:47:31 +08:00
|
|
|
import mmcv
|
2020-11-24 11:21:22 +08:00
|
|
|
import numpy as np
|
2021-03-30 00:49:14 +08:00
|
|
|
import torch
|
2020-11-24 11:21:22 +08:00
|
|
|
|
|
|
|
|
2021-01-10 15:47:31 +08:00
|
|
|
def _class_histogram(values, num_classes):
    """Count how many entries of ``values`` fall on each class id.

    Only ids in ``[0, num_classes)`` are counted. Anything outside that range
    is dropped rather than being folded into the last bin, which is what
    ``torch.histc(..., max=num_classes)`` would do for the value
    ``num_classes``.

    Args:
        values (torch.Tensor): 1-D tensor of class ids.
        num_classes (int): Number of categories.

    Returns:
        torch.Tensor: float32 tensor of shape (num_classes, ) holding the
            per-class counts.
    """
    in_range = (values >= 0) & (values < num_classes)
    return torch.bincount(
        values[in_range].long(), minlength=num_classes).float()


def intersect_and_union(pred_label,
                        label,
                        num_classes,
                        ignore_index,
                        label_map=None,
                        reduce_zero_label=False):
    """Calculate intersection and Union.

    Args:
        pred_label (ndarray | str): Prediction segmentation map
            or predict result filename (loaded with ``np.load``).
        label (ndarray | str): Ground truth segmentation map
            or label filename (read unchanged with ``mmcv.imread``).
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict, optional): Mapping old labels to new labels. Every
            key is matched against the label values as they were before
            remapping, so chained mappings such as ``{1: 2, 2: 3}`` are not
            applied twice. Applied whether ``label`` is an array or a
            filename. Default: None (no remapping).
        reduce_zero_label (bool): Whether to ignore label 0. Label 0 (and
            255) become 255 and every other label is shifted down by one.
            Applied after ``label_map``. Default: False.

    Note:
        A ground-truth ``ndarray`` passed in is never modified; remapping
        always writes into a new tensor.

    Returns:
        torch.Tensor: The intersection of prediction and ground truth
            histogram on all classes.
        torch.Tensor: The union of prediction and ground truth histogram on
            all classes.
        torch.Tensor: The prediction histogram on all classes.
        torch.Tensor: The ground truth histogram on all classes.
    """
    if isinstance(pred_label, str):
        pred_label = torch.from_numpy(np.load(pred_label))
    else:
        pred_label = torch.from_numpy(pred_label)

    if isinstance(label, str):
        label = torch.from_numpy(
            mmcv.imread(label, flag='unchanged', backend='pillow'))
    else:
        # NOTE: this tensor shares memory with the caller's array, so the
        # remapping below must never write into it in place.
        label = torch.from_numpy(label)

    if label_map:
        remapped = label.clone()
        for old_id, new_id in label_map.items():
            # Match on the original values, not on ``remapped``, so an
            # earlier mapping's output is not picked up by a later one.
            remapped[label == old_id] = new_id
        label = remapped

    if reduce_zero_label:
        # Equivalent to "0 -> 255, subtract 1, 254 -> 255", but built on a
        # new tensor: both original 0 and original 255 end up as 255.
        reduced = label - 1
        reduced[(label == 0) | (label == 255)] = 255
        label = reduced

    mask = label != ignore_index
    pred_label = pred_label[mask]
    label = label[mask]

    intersect = pred_label[pred_label == label]
    area_intersect = _class_histogram(intersect, num_classes)
    area_pred_label = _class_histogram(pred_label, num_classes)
    area_label = _class_histogram(label, num_classes)
    area_union = area_pred_label + area_label - area_intersect
    return area_intersect, area_union, area_pred_label, area_label
|
|
|
|
|
|
|
|
|
2021-01-10 15:47:31 +08:00
|
|
|
def total_intersect_and_union(results,
                              gt_seg_maps,
                              num_classes,
                              ignore_index,
                              label_map=None,
                              reduce_zero_label=False):
    """Calculate Total Intersection and Union.

    Per-image histograms from ``intersect_and_union`` are summed over the
    whole dataset.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict, optional): Mapping old labels to new labels.
            Default: None (no remapping).
        reduce_zero_label (bool): Whether to ignore zero label.
            Default: False.

    Returns:
        torch.Tensor: The intersection of prediction and ground truth
            histogram on all classes (float64, shape (num_classes, )).
        torch.Tensor: The union of prediction and ground truth histogram on
            all classes (float64, shape (num_classes, )).
        torch.Tensor: The prediction histogram on all classes
            (float64, shape (num_classes, )).
        torch.Tensor: The ground truth histogram on all classes
            (float64, shape (num_classes, )).
    """
    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs, (
        'got {} predictions but {} ground truth maps'.format(
            num_imgs, len(gt_seg_maps)))

    # float64 accumulators so large datasets don't lose precision.
    total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)

    for result, gt_seg_map in zip(results, gt_seg_maps):
        area_intersect, area_union, area_pred_label, area_label = \
            intersect_and_union(result, gt_seg_map, num_classes,
                                ignore_index, label_map, reduce_zero_label)
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label

    return total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label
|
2020-11-24 11:21:22 +08:00
|
|
|
|
|
|
|
|
2021-01-10 15:47:31 +08:00
|
|
|
def mean_iou(results,
             gt_seg_maps,
             num_classes,
             ignore_index,
             nan_to_num=None,
             label_map=None,
             reduce_zero_label=False):
    """Calculate Mean Intersection and Union (mIoU).

    Thin wrapper around ``eval_metrics`` with ``metrics=['mIoU']``.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict, optional): Mapping old labels to new labels.
            Default: None (no remapping).
        reduce_zero_label (bool): Whether to ignore zero label.
            Default: False.

    Returns:
        float: Overall accuracy on all images (as a 0-d NumPy value).
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category IoU, shape (num_classes, ).
    """
    all_acc, acc, iou = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mIoU'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
    return all_acc, acc, iou
|
|
|
|
|
|
|
|
|
|
|
|
def mean_dice(results,
              gt_seg_maps,
              num_classes,
              ignore_index,
              nan_to_num=None,
              label_map=None,
              reduce_zero_label=False):
    """Calculate Mean Dice (mDice).

    Thin wrapper around ``eval_metrics`` with ``metrics=['mDice']``.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict, optional): Mapping old labels to new labels.
            Default: None (no remapping).
        reduce_zero_label (bool): Whether to ignore zero label.
            Default: False.

    Returns:
        float: Overall accuracy on all images (as a 0-d NumPy value).
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category dice, shape (num_classes, ).
    """
    all_acc, acc, dice = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mDice'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
    return all_acc, acc, dice
|
|
|
|
|
|
|
|
|
|
|
|
def eval_metrics(results,
                 gt_seg_maps,
                 num_classes,
                 ignore_index,
                 metrics=('mIoU', ),
                 nan_to_num=None,
                 label_map=None,
                 reduce_zero_label=False):
    """Calculate evaluation metrics.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        metrics (list[str] | tuple[str] | str): Metrics to be evaluated,
            any of 'mIoU' and 'mDice'. Default: ('mIoU', ).
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict, optional): Mapping old labels to new labels.
            Default: None (no remapping).
        reduce_zero_label (bool): Whether to ignore zero label.
            Default: False.

    Returns:
        list: ``[all_acc, acc, *per_metric]`` as NumPy values, where
            ``all_acc`` is the overall accuracy on all images, ``acc`` is
            the per category accuracy of shape (num_classes, ), followed by
            one per category array of shape (num_classes, ) for every entry
            of ``metrics``, in the given order. Classes whose denominator
            is zero yield NaN unless ``nan_to_num`` is set.

    Raises:
        KeyError: If ``metrics`` contains an unsupported metric name.
    """
    if isinstance(metrics, str):
        metrics = [metrics]
    allowed_metrics = ['mIoU', 'mDice']
    if not set(metrics).issubset(set(allowed_metrics)):
        raise KeyError('metrics {} is not supported'.format(metrics))

    total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label = total_intersect_and_union(
            results, gt_seg_maps, num_classes, ignore_index, label_map,
            reduce_zero_label)

    all_acc = total_area_intersect.sum() / total_area_label.sum()
    acc = total_area_intersect / total_area_label
    ret_metrics = [all_acc, acc]
    for metric in metrics:
        if metric == 'mIoU':
            ret_metrics.append(total_area_intersect / total_area_union)
        elif metric == 'mDice':
            ret_metrics.append(2 * total_area_intersect /
                               (total_area_pred_label + total_area_label))

    ret_metrics = [value.numpy() for value in ret_metrics]
    if nan_to_num is not None:
        ret_metrics = [
            np.nan_to_num(value, nan=nan_to_num) for value in ret_metrics
        ]
    return ret_metrics
|