pytorch metrics implementation (#430)
* pytorch metrics impl and test * support list[str] input, delete unused test code and delete numpy version * modify input data type * add docstring and unitest of filename inputs * add indents in docstring and use tempfile lib to create dir * using with statementpull/450/head
parent
340132dcf4
commit
e86a87f060
|
@ -1,5 +1,6 @@
|
|||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def intersect_and_union(pred_label,
|
||||
|
@ -11,8 +12,10 @@ def intersect_and_union(pred_label,
|
|||
"""Calculate intersection and Union.
|
||||
|
||||
Args:
|
||||
pred_label (ndarray): Prediction segmentation map.
|
||||
label (ndarray): Ground truth segmentation map.
|
||||
pred_label (ndarray | str): Prediction segmentation map
|
||||
or predict result filename.
|
||||
label (ndarray | str): Ground truth segmentation map
|
||||
or label filename.
|
||||
num_classes (int): Number of categories.
|
||||
ignore_index (int): Index that will be ignored in evaluation.
|
||||
label_map (dict): Mapping old labels to new labels. The parameter will
|
||||
|
@ -21,25 +24,29 @@ def intersect_and_union(pred_label,
|
|||
work only when label is str. Default: False.
|
||||
|
||||
Returns:
|
||||
ndarray: The intersection of prediction and ground truth histogram
|
||||
on all classes.
|
||||
ndarray: The union of prediction and ground truth histogram on all
|
||||
classes.
|
||||
ndarray: The prediction histogram on all classes.
|
||||
ndarray: The ground truth histogram on all classes.
|
||||
torch.Tensor: The intersection of prediction and ground truth
|
||||
histogram on all classes.
|
||||
torch.Tensor: The union of prediction and ground truth histogram on
|
||||
all classes.
|
||||
torch.Tensor: The prediction histogram on all classes.
|
||||
torch.Tensor: The ground truth histogram on all classes.
|
||||
"""
|
||||
|
||||
if isinstance(pred_label, str):
|
||||
pred_label = np.load(pred_label)
|
||||
pred_label = torch.from_numpy(np.load(pred_label))
|
||||
else:
|
||||
pred_label = torch.from_numpy((pred_label))
|
||||
|
||||
if isinstance(label, str):
|
||||
label = mmcv.imread(label, flag='unchanged', backend='pillow')
|
||||
# modify if custom classes
|
||||
label = torch.from_numpy(
|
||||
mmcv.imread(label, flag='unchanged', backend='pillow'))
|
||||
else:
|
||||
label = torch.from_numpy(label)
|
||||
|
||||
if label_map is not None:
|
||||
for old_id, new_id in label_map.items():
|
||||
label[label == old_id] = new_id
|
||||
if reduce_zero_label:
|
||||
# avoid using underflow conversion
|
||||
label[label == 0] = 255
|
||||
label = label - 1
|
||||
label[label == 254] = 255
|
||||
|
@ -49,13 +56,13 @@ def intersect_and_union(pred_label,
|
|||
label = label[mask]
|
||||
|
||||
intersect = pred_label[pred_label == label]
|
||||
area_intersect, _ = np.histogram(
|
||||
intersect, bins=np.arange(num_classes + 1))
|
||||
area_pred_label, _ = np.histogram(
|
||||
pred_label, bins=np.arange(num_classes + 1))
|
||||
area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))
|
||||
area_intersect = torch.histc(
|
||||
intersect.float(), bins=(num_classes), min=0, max=num_classes)
|
||||
area_pred_label = torch.histc(
|
||||
pred_label.float(), bins=(num_classes), min=0, max=num_classes)
|
||||
area_label = torch.histc(
|
||||
label.float(), bins=(num_classes), min=0, max=num_classes)
|
||||
area_union = area_pred_label + area_label - area_intersect
|
||||
|
||||
return area_intersect, area_union, area_pred_label, area_label
|
||||
|
||||
|
||||
|
@ -68,8 +75,10 @@ def total_intersect_and_union(results,
|
|||
"""Calculate Total Intersection and Union.
|
||||
|
||||
Args:
|
||||
results (list[ndarray]): List of prediction segmentation maps.
|
||||
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
|
||||
results (list[ndarray] | list[str]): List of prediction segmentation
|
||||
maps or list of prediction result filenames.
|
||||
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
|
||||
segmentation maps or list of label filenames.
|
||||
num_classes (int): Number of categories.
|
||||
ignore_index (int): Index that will be ignored in evaluation.
|
||||
label_map (dict): Mapping old labels to new labels. Default: dict().
|
||||
|
@ -83,23 +92,23 @@ def total_intersect_and_union(results,
|
|||
ndarray: The prediction histogram on all classes.
|
||||
ndarray: The ground truth histogram on all classes.
|
||||
"""
|
||||
|
||||
num_imgs = len(results)
|
||||
assert len(gt_seg_maps) == num_imgs
|
||||
total_area_intersect = np.zeros((num_classes, ), dtype=np.float)
|
||||
total_area_union = np.zeros((num_classes, ), dtype=np.float)
|
||||
total_area_pred_label = np.zeros((num_classes, ), dtype=np.float)
|
||||
total_area_label = np.zeros((num_classes, ), dtype=np.float)
|
||||
total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
|
||||
total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
|
||||
total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
|
||||
total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
|
||||
for i in range(num_imgs):
|
||||
area_intersect, area_union, area_pred_label, area_label = \
|
||||
intersect_and_union(results[i], gt_seg_maps[i], num_classes,
|
||||
ignore_index, label_map, reduce_zero_label)
|
||||
intersect_and_union(
|
||||
results[i], gt_seg_maps[i], num_classes, ignore_index,
|
||||
label_map, reduce_zero_label)
|
||||
total_area_intersect += area_intersect
|
||||
total_area_union += area_union
|
||||
total_area_pred_label += area_pred_label
|
||||
total_area_label += area_label
|
||||
return total_area_intersect, total_area_union, \
|
||||
total_area_pred_label, total_area_label
|
||||
return total_area_intersect, total_area_union, total_area_pred_label, \
|
||||
total_area_label
|
||||
|
||||
|
||||
def mean_iou(results,
|
||||
|
@ -112,8 +121,10 @@ def mean_iou(results,
|
|||
"""Calculate Mean Intersection and Union (mIoU)
|
||||
|
||||
Args:
|
||||
results (list[ndarray]): List of prediction segmentation maps.
|
||||
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
|
||||
results (list[ndarray] | list[str]): List of prediction segmentation
|
||||
maps or list of prediction result filenames.
|
||||
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
|
||||
segmentation maps or list of label filenames.
|
||||
num_classes (int): Number of categories.
|
||||
ignore_index (int): Index that will be ignored in evaluation.
|
||||
nan_to_num (int, optional): If specified, NaN values will be replaced
|
||||
|
@ -126,7 +137,6 @@ def mean_iou(results,
|
|||
ndarray: Per category accuracy, shape (num_classes, ).
|
||||
ndarray: Per category IoU, shape (num_classes, ).
|
||||
"""
|
||||
|
||||
all_acc, acc, iou = eval_metrics(
|
||||
results=results,
|
||||
gt_seg_maps=gt_seg_maps,
|
||||
|
@ -149,8 +159,10 @@ def mean_dice(results,
|
|||
"""Calculate Mean Dice (mDice)
|
||||
|
||||
Args:
|
||||
results (list[ndarray]): List of prediction segmentation maps.
|
||||
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
|
||||
results (list[ndarray] | list[str]): List of prediction segmentation
|
||||
maps or list of prediction result filenames.
|
||||
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
|
||||
segmentation maps or list of label filenames.
|
||||
num_classes (int): Number of categories.
|
||||
ignore_index (int): Index that will be ignored in evaluation.
|
||||
nan_to_num (int, optional): If specified, NaN values will be replaced
|
||||
|
@ -186,8 +198,10 @@ def eval_metrics(results,
|
|||
reduce_zero_label=False):
|
||||
"""Calculate evaluation metrics
|
||||
Args:
|
||||
results (list[ndarray]): List of prediction segmentation maps.
|
||||
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
|
||||
results (list[ndarray] | list[str]): List of prediction segmentation
|
||||
maps or list of prediction result filenames.
|
||||
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
|
||||
segmentation maps or list of label filenames.
|
||||
num_classes (int): Number of categories.
|
||||
ignore_index (int): Index that will be ignored in evaluation.
|
||||
metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
|
||||
|
@ -200,17 +214,16 @@ def eval_metrics(results,
|
|||
ndarray: Per category accuracy, shape (num_classes, ).
|
||||
ndarray: Per category evalution metrics, shape (num_classes, ).
|
||||
"""
|
||||
|
||||
if isinstance(metrics, str):
|
||||
metrics = [metrics]
|
||||
allowed_metrics = ['mIoU', 'mDice']
|
||||
if not set(metrics).issubset(set(allowed_metrics)):
|
||||
raise KeyError('metrics {} is not supported'.format(metrics))
|
||||
|
||||
total_area_intersect, total_area_union, total_area_pred_label, \
|
||||
total_area_label = total_intersect_and_union(results, gt_seg_maps,
|
||||
num_classes, ignore_index,
|
||||
label_map,
|
||||
reduce_zero_label)
|
||||
total_area_label = total_intersect_and_union(
|
||||
results, gt_seg_maps, num_classes, ignore_index, label_map,
|
||||
reduce_zero_label)
|
||||
all_acc = total_area_intersect.sum() / total_area_label.sum()
|
||||
acc = total_area_intersect / total_area_label
|
||||
ret_metrics = [all_acc, acc]
|
||||
|
@ -222,6 +235,7 @@ def eval_metrics(results,
|
|||
dice = 2 * total_area_intersect / (
|
||||
total_area_pred_label + total_area_label)
|
||||
ret_metrics.append(dice)
|
||||
ret_metrics = [metric.numpy() for metric in ret_metrics]
|
||||
if nan_to_num is not None:
|
||||
ret_metrics = [
|
||||
np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics
|
||||
|
|
|
@ -164,3 +164,45 @@ def test_mean_dice():
|
|||
results, label, num_classes, ignore_index=255, nan_to_num=-1)
|
||||
assert acc[-1] == -1
|
||||
assert iou[-1] == -1
|
||||
|
||||
|
||||
def test_filename_inputs():
|
||||
import cv2
|
||||
import tempfile
|
||||
|
||||
def save_arr(input_arrays: list, title: str, is_image: bool, dir: str):
|
||||
filenames = []
|
||||
SUFFIX = '.png' if is_image else '.npy'
|
||||
for idx, arr in enumerate(input_arrays):
|
||||
filename = '{}/{}-{}{}'.format(dir, title, idx, SUFFIX)
|
||||
if is_image:
|
||||
cv2.imwrite(filename, arr)
|
||||
else:
|
||||
np.save(filename, arr)
|
||||
filenames.append(filename)
|
||||
return filenames
|
||||
|
||||
pred_size = (10, 512, 1024)
|
||||
num_classes = 19
|
||||
ignore_index = 255
|
||||
results = np.random.randint(0, num_classes, size=pred_size)
|
||||
labels = np.random.randint(0, num_classes, size=pred_size)
|
||||
labels[:, 2, 5:10] = ignore_index
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
|
||||
result_files = save_arr(results, 'pred', False, temp_dir)
|
||||
label_files = save_arr(labels, 'label', True, temp_dir)
|
||||
|
||||
all_acc, acc, iou = eval_metrics(
|
||||
result_files,
|
||||
label_files,
|
||||
num_classes,
|
||||
ignore_index,
|
||||
metrics='mIoU')
|
||||
|
||||
all_acc_l, acc_l, iou_l = legacy_mean_iou(results, labels, num_classes,
|
||||
ignore_index)
|
||||
assert all_acc == all_acc_l
|
||||
assert np.allclose(acc, acc_l)
|
||||
assert np.allclose(iou, iou_l)
|
||||
|
|
Loading…
Reference in New Issue