pytorch metrics implementation (#430)

* pytorch metrics impl and test

* support list[str] input, delete unused test code and delete numpy version

* modify input data type

* add docstring and unitest of filename inputs

* add indents in docstring and use tempfile lib to create dir

* using with statement
This commit is contained in:
谢昕辰 2021-03-30 00:49:14 +08:00 committed by GitHub
parent 15faf716de
commit d474cfde4b
2 changed files with 97 additions and 41 deletions

View File

@ -1,5 +1,6 @@
import mmcv import mmcv
import numpy as np import numpy as np
import torch
def intersect_and_union(pred_label, def intersect_and_union(pred_label,
@ -11,8 +12,10 @@ def intersect_and_union(pred_label,
"""Calculate intersection and Union. """Calculate intersection and Union.
Args: Args:
pred_label (ndarray): Prediction segmentation map. pred_label (ndarray | str): Prediction segmentation map
label (ndarray): Ground truth segmentation map. or predict result filename.
label (ndarray | str): Ground truth segmentation map
or label filename.
num_classes (int): Number of categories. num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation. ignore_index (int): Index that will be ignored in evaluation.
label_map (dict): Mapping old labels to new labels. The parameter will label_map (dict): Mapping old labels to new labels. The parameter will
@ -21,25 +24,29 @@ def intersect_and_union(pred_label,
work only when label is str. Default: False. work only when label is str. Default: False.
Returns: Returns:
ndarray: The intersection of prediction and ground truth histogram torch.Tensor: The intersection of prediction and ground truth
on all classes. histogram on all classes.
ndarray: The union of prediction and ground truth histogram on all torch.Tensor: The union of prediction and ground truth histogram on
classes. all classes.
ndarray: The prediction histogram on all classes. torch.Tensor: The prediction histogram on all classes.
ndarray: The ground truth histogram on all classes. torch.Tensor: The ground truth histogram on all classes.
""" """
if isinstance(pred_label, str): if isinstance(pred_label, str):
pred_label = np.load(pred_label) pred_label = torch.from_numpy(np.load(pred_label))
else:
pred_label = torch.from_numpy((pred_label))
if isinstance(label, str): if isinstance(label, str):
label = mmcv.imread(label, flag='unchanged', backend='pillow') label = torch.from_numpy(
# modify if custom classes mmcv.imread(label, flag='unchanged', backend='pillow'))
else:
label = torch.from_numpy(label)
if label_map is not None: if label_map is not None:
for old_id, new_id in label_map.items(): for old_id, new_id in label_map.items():
label[label == old_id] = new_id label[label == old_id] = new_id
if reduce_zero_label: if reduce_zero_label:
# avoid using underflow conversion
label[label == 0] = 255 label[label == 0] = 255
label = label - 1 label = label - 1
label[label == 254] = 255 label[label == 254] = 255
@ -49,13 +56,13 @@ def intersect_and_union(pred_label,
label = label[mask] label = label[mask]
intersect = pred_label[pred_label == label] intersect = pred_label[pred_label == label]
area_intersect, _ = np.histogram( area_intersect = torch.histc(
intersect, bins=np.arange(num_classes + 1)) intersect.float(), bins=(num_classes), min=0, max=num_classes)
area_pred_label, _ = np.histogram( area_pred_label = torch.histc(
pred_label, bins=np.arange(num_classes + 1)) pred_label.float(), bins=(num_classes), min=0, max=num_classes)
area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1)) area_label = torch.histc(
label.float(), bins=(num_classes), min=0, max=num_classes)
area_union = area_pred_label + area_label - area_intersect area_union = area_pred_label + area_label - area_intersect
return area_intersect, area_union, area_pred_label, area_label return area_intersect, area_union, area_pred_label, area_label
@ -68,8 +75,10 @@ def total_intersect_and_union(results,
"""Calculate Total Intersection and Union. """Calculate Total Intersection and Union.
Args: Args:
results (list[ndarray]): List of prediction segmentation maps. results (list[ndarray] | list[str]): List of prediction segmentation
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. maps or list of prediction result filenames.
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
segmentation maps or list of label filenames.
num_classes (int): Number of categories. num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation. ignore_index (int): Index that will be ignored in evaluation.
label_map (dict): Mapping old labels to new labels. Default: dict(). label_map (dict): Mapping old labels to new labels. Default: dict().
@ -83,23 +92,23 @@ def total_intersect_and_union(results,
ndarray: The prediction histogram on all classes. ndarray: The prediction histogram on all classes.
ndarray: The ground truth histogram on all classes. ndarray: The ground truth histogram on all classes.
""" """
num_imgs = len(results) num_imgs = len(results)
assert len(gt_seg_maps) == num_imgs assert len(gt_seg_maps) == num_imgs
total_area_intersect = np.zeros((num_classes, ), dtype=np.float) total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_union = np.zeros((num_classes, ), dtype=np.float) total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_pred_label = np.zeros((num_classes, ), dtype=np.float) total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_label = np.zeros((num_classes, ), dtype=np.float) total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
for i in range(num_imgs): for i in range(num_imgs):
area_intersect, area_union, area_pred_label, area_label = \ area_intersect, area_union, area_pred_label, area_label = \
intersect_and_union(results[i], gt_seg_maps[i], num_classes, intersect_and_union(
ignore_index, label_map, reduce_zero_label) results[i], gt_seg_maps[i], num_classes, ignore_index,
label_map, reduce_zero_label)
total_area_intersect += area_intersect total_area_intersect += area_intersect
total_area_union += area_union total_area_union += area_union
total_area_pred_label += area_pred_label total_area_pred_label += area_pred_label
total_area_label += area_label total_area_label += area_label
return total_area_intersect, total_area_union, \ return total_area_intersect, total_area_union, total_area_pred_label, \
total_area_pred_label, total_area_label total_area_label
def mean_iou(results, def mean_iou(results,
@ -112,8 +121,10 @@ def mean_iou(results,
"""Calculate Mean Intersection and Union (mIoU) """Calculate Mean Intersection and Union (mIoU)
Args: Args:
results (list[ndarray]): List of prediction segmentation maps. results (list[ndarray] | list[str]): List of prediction segmentation
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. maps or list of prediction result filenames.
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
segmentation maps or list of label filenames.
num_classes (int): Number of categories. num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation. ignore_index (int): Index that will be ignored in evaluation.
nan_to_num (int, optional): If specified, NaN values will be replaced nan_to_num (int, optional): If specified, NaN values will be replaced
@ -126,7 +137,6 @@ def mean_iou(results,
ndarray: Per category accuracy, shape (num_classes, ). ndarray: Per category accuracy, shape (num_classes, ).
ndarray: Per category IoU, shape (num_classes, ). ndarray: Per category IoU, shape (num_classes, ).
""" """
all_acc, acc, iou = eval_metrics( all_acc, acc, iou = eval_metrics(
results=results, results=results,
gt_seg_maps=gt_seg_maps, gt_seg_maps=gt_seg_maps,
@ -149,8 +159,10 @@ def mean_dice(results,
"""Calculate Mean Dice (mDice) """Calculate Mean Dice (mDice)
Args: Args:
results (list[ndarray]): List of prediction segmentation maps. results (list[ndarray] | list[str]): List of prediction segmentation
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. maps or list of prediction result filenames.
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
segmentation maps or list of label filenames.
num_classes (int): Number of categories. num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation. ignore_index (int): Index that will be ignored in evaluation.
nan_to_num (int, optional): If specified, NaN values will be replaced nan_to_num (int, optional): If specified, NaN values will be replaced
@ -186,8 +198,10 @@ def eval_metrics(results,
reduce_zero_label=False): reduce_zero_label=False):
"""Calculate evaluation metrics """Calculate evaluation metrics
Args: Args:
results (list[ndarray]): List of prediction segmentation maps. results (list[ndarray] | list[str]): List of prediction segmentation
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. maps or list of prediction result filenames.
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
segmentation maps or list of label filenames.
num_classes (int): Number of categories. num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation. ignore_index (int): Index that will be ignored in evaluation.
metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'. metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
@ -200,17 +214,16 @@ def eval_metrics(results,
ndarray: Per category accuracy, shape (num_classes, ). ndarray: Per category accuracy, shape (num_classes, ).
ndarray: Per category evalution metrics, shape (num_classes, ). ndarray: Per category evalution metrics, shape (num_classes, ).
""" """
if isinstance(metrics, str): if isinstance(metrics, str):
metrics = [metrics] metrics = [metrics]
allowed_metrics = ['mIoU', 'mDice'] allowed_metrics = ['mIoU', 'mDice']
if not set(metrics).issubset(set(allowed_metrics)): if not set(metrics).issubset(set(allowed_metrics)):
raise KeyError('metrics {} is not supported'.format(metrics)) raise KeyError('metrics {} is not supported'.format(metrics))
total_area_intersect, total_area_union, total_area_pred_label, \ total_area_intersect, total_area_union, total_area_pred_label, \
total_area_label = total_intersect_and_union(results, gt_seg_maps, total_area_label = total_intersect_and_union(
num_classes, ignore_index, results, gt_seg_maps, num_classes, ignore_index, label_map,
label_map, reduce_zero_label)
reduce_zero_label)
all_acc = total_area_intersect.sum() / total_area_label.sum() all_acc = total_area_intersect.sum() / total_area_label.sum()
acc = total_area_intersect / total_area_label acc = total_area_intersect / total_area_label
ret_metrics = [all_acc, acc] ret_metrics = [all_acc, acc]
@ -222,6 +235,7 @@ def eval_metrics(results,
dice = 2 * total_area_intersect / ( dice = 2 * total_area_intersect / (
total_area_pred_label + total_area_label) total_area_pred_label + total_area_label)
ret_metrics.append(dice) ret_metrics.append(dice)
ret_metrics = [metric.numpy() for metric in ret_metrics]
if nan_to_num is not None: if nan_to_num is not None:
ret_metrics = [ ret_metrics = [
np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics

View File

@ -164,3 +164,45 @@ def test_mean_dice():
results, label, num_classes, ignore_index=255, nan_to_num=-1) results, label, num_classes, ignore_index=255, nan_to_num=-1)
assert acc[-1] == -1 assert acc[-1] == -1
assert iou[-1] == -1 assert iou[-1] == -1
def test_filename_inputs():
import cv2
import tempfile
def save_arr(input_arrays: list, title: str, is_image: bool, dir: str):
filenames = []
SUFFIX = '.png' if is_image else '.npy'
for idx, arr in enumerate(input_arrays):
filename = '{}/{}-{}{}'.format(dir, title, idx, SUFFIX)
if is_image:
cv2.imwrite(filename, arr)
else:
np.save(filename, arr)
filenames.append(filename)
return filenames
pred_size = (10, 512, 1024)
num_classes = 19
ignore_index = 255
results = np.random.randint(0, num_classes, size=pred_size)
labels = np.random.randint(0, num_classes, size=pred_size)
labels[:, 2, 5:10] = ignore_index
with tempfile.TemporaryDirectory() as temp_dir:
result_files = save_arr(results, 'pred', False, temp_dir)
label_files = save_arr(labels, 'label', True, temp_dir)
all_acc, acc, iou = eval_metrics(
result_files,
label_files,
num_classes,
ignore_index,
metrics='mIoU')
all_acc_l, acc_l, iou_l = legacy_mean_iou(results, labels, num_classes,
ignore_index)
assert all_acc == all_acc_l
assert np.allclose(acc, acc_l)
assert np.allclose(iou, iou_l)