pytorch metrics implementation (#430)

* pytorch metrics impl and test

* support list[str] input, delete unused test code and delete numpy version

* modify input data type

* add docstring and unit test for filename inputs

* add indents in docstring and use tempfile lib to create dir

* use with statement
谢昕辰 2021-03-30 00:49:14 +08:00 committed by GitHub
parent 340132dcf4
commit e86a87f060
2 changed files with 97 additions and 41 deletions


@@ -1,5 +1,6 @@
 import mmcv
 import numpy as np
+import torch


 def intersect_and_union(pred_label,
@ -11,8 +12,10 @@ def intersect_and_union(pred_label,
"""Calculate intersection and Union.
Args:
pred_label (ndarray): Prediction segmentation map.
label (ndarray): Ground truth segmentation map.
pred_label (ndarray | str): Prediction segmentation map
or predict result filename.
label (ndarray | str): Ground truth segmentation map
or label filename.
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
label_map (dict): Mapping old labels to new labels. The parameter will
@@ -21,25 +24,29 @@ def intersect_and_union(pred_label,
             work only when label is str. Default: False.

     Returns:
-        ndarray: The intersection of prediction and ground truth histogram
-            on all classes.
-        ndarray: The union of prediction and ground truth histogram on all
-            classes.
-        ndarray: The prediction histogram on all classes.
-        ndarray: The ground truth histogram on all classes.
+        torch.Tensor: The intersection of prediction and ground truth
+            histogram on all classes.
+        torch.Tensor: The union of prediction and ground truth histogram on
+            all classes.
+        torch.Tensor: The prediction histogram on all classes.
+        torch.Tensor: The ground truth histogram on all classes.
     """

     if isinstance(pred_label, str):
-        pred_label = np.load(pred_label)
+        pred_label = torch.from_numpy(np.load(pred_label))
+    else:
+        pred_label = torch.from_numpy((pred_label))

     if isinstance(label, str):
-        label = mmcv.imread(label, flag='unchanged', backend='pillow')
-    # modify if custom classes
+        label = torch.from_numpy(
+            mmcv.imread(label, flag='unchanged', backend='pillow'))
+    else:
+        label = torch.from_numpy(label)
     if label_map is not None:
         for old_id, new_id in label_map.items():
             label[label == old_id] = new_id
     if reduce_zero_label:
         # avoid using underflow conversion
         label[label == 0] = 255
         label = label - 1
         label[label == 254] = 255
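The reduce_zero_label branch is worth a close read: class 0 (background) is first mapped to the ignore value 255, every remaining class shifts down by one, and labels that were originally 255 (becoming 254 after the shift) are restored to 255. Mapping 0 to 255 before subtracting is what the "avoid using underflow conversion" comment refers to: a uint8 label would wrap around on 0 - 1. A standalone sketch of the remap, not part of the commit:

import torch

# reduce_zero_label remap: class 0 -> ignored (255), classes 1..N -> 0..N-1,
# and labels that were already 255 stay ignored.
label = torch.tensor([0, 1, 2, 255])
label[label == 0] = 255
label = label - 1          # original 255s become 254 here
label[label == 254] = 255  # restore them to the ignore value
print(label)               # tensor([255,   0,   1, 255])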
@ -49,13 +56,13 @@ def intersect_and_union(pred_label,
label = label[mask]
intersect = pred_label[pred_label == label]
area_intersect, _ = np.histogram(
intersect, bins=np.arange(num_classes + 1))
area_pred_label, _ = np.histogram(
pred_label, bins=np.arange(num_classes + 1))
area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))
area_intersect = torch.histc(
intersect.float(), bins=(num_classes), min=0, max=num_classes)
area_pred_label = torch.histc(
pred_label.float(), bins=(num_classes), min=0, max=num_classes)
area_label = torch.histc(
label.float(), bins=(num_classes), min=0, max=num_classes)
area_union = area_pred_label + area_label - area_intersect
return area_intersect, area_union, area_pred_label, area_label
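The switch from np.histogram to torch.histc keeps the binning semantics: for integer class ids in [0, num_classes), histc with bins=num_classes, min=0, max=num_classes produces the same counts as np.histogram with bin edges np.arange(num_classes + 1). A quick equivalence check, a sketch rather than part of the commit:

import numpy as np
import torch

num_classes = 5
x = np.random.randint(0, num_classes, size=1000)

# numpy binning used by the old implementation
np_hist, _ = np.histogram(x, bins=np.arange(num_classes + 1))
# torch binning used by the new implementation
torch_hist = torch.histc(
    torch.from_numpy(x).float(), bins=num_classes, min=0, max=num_classes)

assert np.allclose(torch_hist.numpy(), np_hist)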
@@ -68,8 +75,10 @@ def total_intersect_and_union(results,
     """Calculate Total Intersection and Union.

     Args:
-        results (list[ndarray]): List of prediction segmentation maps.
-        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+            segmentation maps or list of label filenames.
         num_classes (int): Number of categories.
         ignore_index (int): Index that will be ignored in evaluation.
         label_map (dict): Mapping old labels to new labels. Default: dict().
@ -83,23 +92,23 @@ def total_intersect_and_union(results,
ndarray: The prediction histogram on all classes.
ndarray: The ground truth histogram on all classes.
"""
num_imgs = len(results)
assert len(gt_seg_maps) == num_imgs
total_area_intersect = np.zeros((num_classes, ), dtype=np.float)
total_area_union = np.zeros((num_classes, ), dtype=np.float)
total_area_pred_label = np.zeros((num_classes, ), dtype=np.float)
total_area_label = np.zeros((num_classes, ), dtype=np.float)
total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
for i in range(num_imgs):
area_intersect, area_union, area_pred_label, area_label = \
intersect_and_union(results[i], gt_seg_maps[i], num_classes,
ignore_index, label_map, reduce_zero_label)
intersect_and_union(
results[i], gt_seg_maps[i], num_classes, ignore_index,
label_map, reduce_zero_label)
total_area_intersect += area_intersect
total_area_union += area_union
total_area_pred_label += area_pred_label
total_area_label += area_label
return total_area_intersect, total_area_union, \
total_area_pred_label, total_area_label
return total_area_intersect, total_area_union, total_area_pred_label, \
total_area_label
def mean_iou(results,
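A small usage sketch for the accumulator, assuming the function is importable from mmseg.core.evaluation.metrics (the file path is not shown in this diff): the per-image tensors from intersect_and_union are summed into float64 totals, and batch-level per-class IoU falls out as intersect / union.

import numpy as np
from mmseg.core.evaluation.metrics import total_intersect_and_union  # path assumed

num_classes = 3
results = [np.random.randint(0, num_classes, size=(8, 8)) for _ in range(2)]
gt_seg_maps = [np.random.randint(0, num_classes, size=(8, 8)) for _ in range(2)]

intersect, union, pred_area, label_area = total_intersect_and_union(
    results, gt_seg_maps, num_classes, ignore_index=255,
    label_map=dict(), reduce_zero_label=False)
iou = intersect / union  # per-class IoU over the whole batch (torch.Tensor)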
@@ -112,8 +121,10 @@ def mean_iou(results,
     """Calculate Mean Intersection and Union (mIoU)

     Args:
-        results (list[ndarray]): List of prediction segmentation maps.
-        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+            segmentation maps or list of label filenames.
         num_classes (int): Number of categories.
         ignore_index (int): Index that will be ignored in evaluation.
         nan_to_num (int, optional): If specified, NaN values will be replaced
@@ -126,7 +137,6 @@ def mean_iou(results,
         ndarray: Per category accuracy, shape (num_classes, ).
         ndarray: Per category IoU, shape (num_classes, ).
     """
-
     all_acc, acc, iou = eval_metrics(
         results=results,
         gt_seg_maps=gt_seg_maps,
@@ -149,8 +159,10 @@ def mean_dice(results,
     """Calculate Mean Dice (mDice)

     Args:
-        results (list[ndarray]): List of prediction segmentation maps.
-        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+            segmentation maps or list of label filenames.
         num_classes (int): Number of categories.
         ignore_index (int): Index that will be ignored in evaluation.
         nan_to_num (int, optional): If specified, NaN values will be replaced
@@ -186,8 +198,10 @@ def eval_metrics(results,
                  reduce_zero_label=False):
     """Calculate evaluation metrics

     Args:
-        results (list[ndarray]): List of prediction segmentation maps.
-        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+            segmentation maps or list of label filenames.
         num_classes (int): Number of categories.
         ignore_index (int): Index that will be ignored in evaluation.
         metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
@@ -200,17 +214,16 @@ def eval_metrics(results,
         ndarray: Per category accuracy, shape (num_classes, ).
         ndarray: Per category evaluation metrics, shape (num_classes, ).
     """

     if isinstance(metrics, str):
         metrics = [metrics]
     allowed_metrics = ['mIoU', 'mDice']
     if not set(metrics).issubset(set(allowed_metrics)):
         raise KeyError('metrics {} is not supported'.format(metrics))
     total_area_intersect, total_area_union, total_area_pred_label, \
-        total_area_label = total_intersect_and_union(results, gt_seg_maps,
-                                                     num_classes, ignore_index,
-                                                     label_map,
-                                                     reduce_zero_label)
+        total_area_label = total_intersect_and_union(
+            results, gt_seg_maps, num_classes, ignore_index, label_map,
+            reduce_zero_label)
     all_acc = total_area_intersect.sum() / total_area_label.sum()
     acc = total_area_intersect / total_area_label
     ret_metrics = [all_acc, acc]
@@ -222,6 +235,7 @@ def eval_metrics(results,
         dice = 2 * total_area_intersect / (
             total_area_pred_label + total_area_label)
         ret_metrics.append(dice)
+    ret_metrics = [metric.numpy() for metric in ret_metrics]
     if nan_to_num is not None:
         ret_metrics = [
             np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics
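End-to-end, eval_metrics now computes in torch but still accepts and returns numpy, so callers are unaffected by the change. A usage sketch, with the import path assumed since it is not shown in this diff:

import numpy as np
from mmseg.core.evaluation import eval_metrics  # path assumed

num_classes = 4
pred = np.random.randint(0, num_classes, size=(32, 32))
gt = np.random.randint(0, num_classes, size=(32, 32))

# inputs may be ndarrays (as here) or filenames; outputs are numpy again
all_acc, acc, iou = eval_metrics(
    [pred], [gt], num_classes, ignore_index=255, metrics='mIoU')
all_acc, acc, dice = eval_metrics(
    [pred], [gt], num_classes, ignore_index=255, metrics='mDice')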


@@ -164,3 +164,45 @@ def test_mean_dice():
         results, label, num_classes, ignore_index=255, nan_to_num=-1)
     assert acc[-1] == -1
     assert iou[-1] == -1
+
+
+def test_filename_inputs():
+    import cv2
+    import tempfile
+
+    def save_arr(input_arrays: list, title: str, is_image: bool, dir: str):
+        filenames = []
+        SUFFIX = '.png' if is_image else '.npy'
+        for idx, arr in enumerate(input_arrays):
+            filename = '{}/{}-{}{}'.format(dir, title, idx, SUFFIX)
+            if is_image:
+                cv2.imwrite(filename, arr)
+            else:
+                np.save(filename, arr)
+            filenames.append(filename)
+        return filenames
+
+    pred_size = (10, 512, 1024)
+    num_classes = 19
+    ignore_index = 255
+
+    results = np.random.randint(0, num_classes, size=pred_size)
+    labels = np.random.randint(0, num_classes, size=pred_size)
+    labels[:, 2, 5:10] = ignore_index
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        result_files = save_arr(results, 'pred', False, temp_dir)
+        label_files = save_arr(labels, 'label', True, temp_dir)
+
+        all_acc, acc, iou = eval_metrics(
+            result_files,
+            label_files,
+            num_classes,
+            ignore_index,
+            metrics='mIoU')
+
+        all_acc_l, acc_l, iou_l = legacy_mean_iou(results, labels, num_classes,
+                                                  ignore_index)
+        assert all_acc == all_acc_l
+        assert np.allclose(acc, acc_l)
+        assert np.allclose(iou, iou_l)
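legacy_mean_iou is defined elsewhere in the test module and is not part of this hunk. As a reference for what the assertions compare against, a minimal numpy implementation with the same contract might look like the following hypothetical sketch (not the actual helper):

import numpy as np

def legacy_mean_iou(results, gt_seg_maps, num_classes, ignore_index):
    # Pure-numpy reference mirroring the pre-PR histogram-based computation.
    bins = np.arange(num_classes + 1)
    total_intersect = np.zeros((num_classes, ))
    total_pred = np.zeros((num_classes, ))
    total_label = np.zeros((num_classes, ))
    for pred, label in zip(results, gt_seg_maps):
        mask = label != ignore_index
        pred, label = pred[mask], label[mask]
        intersect = pred[pred == label]
        total_intersect += np.histogram(intersect, bins=bins)[0]
        total_pred += np.histogram(pred, bins=bins)[0]
        total_label += np.histogram(label, bins=bins)[0]
    total_union = total_pred + total_label - total_intersect
    all_acc = total_intersect.sum() / total_label.sum()
    acc = total_intersect / total_label
    iou = total_intersect / total_union
    return all_acc, acc, iou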