add metric mFscore (#509)

* add mFscore and refactor the metrics return value

* fix linting

* some docstring and name fixes
sshuair 2021-05-01 02:34:57 +08:00 committed by GitHub
parent cf2cb542f7
commit e16e0e303b
7 changed files with 319 additions and 86 deletions
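For context, this refactor changes eval_metrics from returning a tuple (all_acc, acc, metric) to returning an ordered dict keyed by metric name. A minimal sketch of the new calling convention, assuming mmsegmentation at this commit is installed; the inputs are random toy data in the same style as tests/test_metrics.py, not real output:

import numpy as np

from mmseg.core.evaluation import eval_metrics

# Toy predictions and ground truths: four 30x30 label maps with 19 classes.
results = [np.random.randint(0, 19, size=(30, 30)) for _ in range(4)]
gt_seg_maps = [np.random.randint(0, 19, size=(30, 30)) for _ in range(4)]

ret = eval_metrics(
    results,
    gt_seg_maps,
    num_classes=19,
    ignore_index=255,
    metrics=['mIoU', 'mFscore'])

# 'aAcc' is the overall pixel accuracy; 'IoU', 'Acc', 'Fscore', 'Precision'
# and 'Recall' are per-class arrays of shape (num_classes,).
print(float(ret['aAcc']), np.nanmean(ret['IoU']), np.nanmean(ret['Fscore']))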


@@ -1,8 +1,8 @@
 from .class_names import get_classes, get_palette
 from .eval_hooks import DistEvalHook, EvalHook
-from .metrics import eval_metrics, mean_dice, mean_iou
+from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou

 __all__ = [
-    'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'eval_metrics',
-    'get_classes', 'get_palette'
+    'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
+    'eval_metrics', 'get_classes', 'get_palette'
 ]


@@ -1,8 +1,27 @@
+from collections import OrderedDict
+
 import mmcv
 import numpy as np
 import torch


+def f_score(precision, recall, beta=1):
+    """Calculate the f-score value.
+
+    Args:
+        precision (float | torch.Tensor): The precision value.
+        recall (float | torch.Tensor): The recall value.
+        beta (int): Determines the weight of recall in the combined score.
+            Default: 1.
+
+    Returns:
+        torch.Tensor: The f-score value.
+    """
+    score = (1 + beta**2) * (precision * recall) / (
+        (beta**2 * precision) + recall)
+    return score
+
+
 def intersect_and_union(pred_label,
                         label,
                         num_classes,
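The f_score helper added above implements the beta-weighted F-score, F_beta = (1 + beta^2) * precision * recall / (beta^2 * precision + recall); beta=1 gives the usual harmonic mean of precision and recall, and beta > 1 weights recall more heavily. A quick sanity check with arbitrary numbers (not taken from the repository):

# Arbitrary precision/recall values, purely illustrative.
precision, recall = 0.5, 0.8

f1 = (1 + 1**2) * precision * recall / (1**2 * precision + recall)  # ~0.615
f2 = (1 + 2**2) * precision * recall / (2**2 * precision + recall)  # ~0.714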
@@ -133,11 +152,12 @@ def mean_iou(results,
         reduce_zero_label (bool): Wether ignore zero label. Default: False.

     Returns:
-        float: Overall accuracy on all images.
-        ndarray: Per category accuracy, shape (num_classes, ).
-        ndarray: Per category IoU, shape (num_classes, ).
+        dict[str, float | ndarray]:
+            <aAcc> float: Overall accuracy on all images.
+            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
+            <IoU> ndarray: Per category IoU, shape (num_classes, ).
     """
-    all_acc, acc, iou = eval_metrics(
+    iou_result = eval_metrics(
         results=results,
         gt_seg_maps=gt_seg_maps,
         num_classes=num_classes,
@@ -146,7 +166,7 @@ def mean_iou(results,
         nan_to_num=nan_to_num,
         label_map=label_map,
         reduce_zero_label=reduce_zero_label)
-    return all_acc, acc, iou
+    return iou_result


 def mean_dice(results,
@@ -171,12 +191,13 @@ def mean_dice(results,
         reduce_zero_label (bool): Wether ignore zero label. Default: False.

     Returns:
-        float: Overall accuracy on all images.
-        ndarray: Per category accuracy, shape (num_classes, ).
-        ndarray: Per category dice, shape (num_classes, ).
+        dict[str, float | ndarray]: Default metrics.
+            <aAcc> float: Overall accuracy on all images.
+            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
+            <Dice> ndarray: Per category dice, shape (num_classes, ).
     """
-    all_acc, acc, dice = eval_metrics(
+    dice_result = eval_metrics(
         results=results,
         gt_seg_maps=gt_seg_maps,
         num_classes=num_classes,
@@ -185,7 +206,52 @@ def mean_dice(results,
         nan_to_num=nan_to_num,
         label_map=label_map,
         reduce_zero_label=reduce_zero_label)
-    return all_acc, acc, dice
+    return dice_result
+
+
+def mean_fscore(results,
+                gt_seg_maps,
+                num_classes,
+                ignore_index,
+                nan_to_num=None,
+                label_map=dict(),
+                reduce_zero_label=False,
+                beta=1):
+    """Calculate Mean F-Score (mFscore)
+
+    Args:
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+            segmentation maps or list of label filenames.
+        num_classes (int): Number of categories.
+        ignore_index (int): Index that will be ignored in evaluation.
+        nan_to_num (int, optional): If specified, NaN values will be replaced
+            by the numbers defined by the user. Default: None.
+        label_map (dict): Mapping old labels to new labels. Default: dict().
+        reduce_zero_label (bool): Whether ignore zero label. Default: False.
+        beta (int): Determines the weight of recall in the combined score.
+            Default: 1.
+
+    Returns:
+        dict[str, float | ndarray]: Default metrics.
+            <aAcc> float: Overall accuracy on all images.
+            <Fscore> ndarray: Per category f-score, shape (num_classes, ).
+            <Precision> ndarray: Per category precision, shape (num_classes, ).
+            <Recall> ndarray: Per category recall, shape (num_classes, ).
+    """
+    fscore_result = eval_metrics(
+        results=results,
+        gt_seg_maps=gt_seg_maps,
+        num_classes=num_classes,
+        ignore_index=ignore_index,
+        metrics=['mFscore'],
+        nan_to_num=nan_to_num,
+        label_map=label_map,
+        reduce_zero_label=reduce_zero_label,
+        beta=beta)
+    return fscore_result


 def eval_metrics(results,
@@ -195,7 +261,8 @@ def eval_metrics(results,
                  metrics=['mIoU'],
                  nan_to_num=None,
                  label_map=dict(),
-                 reduce_zero_label=False):
+                 reduce_zero_label=False,
+                 beta=1):
     """Calculate evaluation metrics
     Args:
         results (list[ndarray] | list[str]): List of prediction segmentation
@@ -210,13 +277,13 @@ def eval_metrics(results,
         label_map (dict): Mapping old labels to new labels. Default: dict().
         reduce_zero_label (bool): Wether ignore zero label. Default: False.
     Returns:
         float: Overall accuracy on all images.
         ndarray: Per category accuracy, shape (num_classes, ).
         ndarray: Per category evaluation metrics, shape (num_classes, ).
     """
     if isinstance(metrics, str):
         metrics = [metrics]
-    allowed_metrics = ['mIoU', 'mDice']
+    allowed_metrics = ['mIoU', 'mDice', 'mFscore']
     if not set(metrics).issubset(set(allowed_metrics)):
         raise KeyError('metrics {} is not supported'.format(metrics))
@@ -225,19 +292,35 @@ def eval_metrics(results,
         results, gt_seg_maps, num_classes, ignore_index, label_map,
         reduce_zero_label)
     all_acc = total_area_intersect.sum() / total_area_label.sum()
-    acc = total_area_intersect / total_area_label
-    ret_metrics = [all_acc, acc]
+    ret_metrics = OrderedDict({'aAcc': all_acc})
     for metric in metrics:
         if metric == 'mIoU':
             iou = total_area_intersect / total_area_union
-            ret_metrics.append(iou)
+            acc = total_area_intersect / total_area_label
+            ret_metrics['IoU'] = iou
+            ret_metrics['Acc'] = acc
         elif metric == 'mDice':
             dice = 2 * total_area_intersect / (
                 total_area_pred_label + total_area_label)
-            ret_metrics.append(dice)
-    ret_metrics = [metric.numpy() for metric in ret_metrics]
+            acc = total_area_intersect / total_area_label
+            ret_metrics['Dice'] = dice
+            ret_metrics['Acc'] = acc
+        elif metric == 'mFscore':
+            precision = total_area_intersect / total_area_pred_label
+            recall = total_area_intersect / total_area_label
+            f_value = torch.tensor(
+                [f_score(x[0], x[1], beta) for x in zip(precision, recall)])
+            ret_metrics['Fscore'] = f_value
+            ret_metrics['Precision'] = precision
+            ret_metrics['Recall'] = recall
+    ret_metrics = {
+        metric: value.numpy()
+        for metric, value in ret_metrics.items()
+    }
     if nan_to_num is not None:
-        ret_metrics = [
-            np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics
-        ]
+        ret_metrics = OrderedDict({
+            metric: np.nan_to_num(metric_value, nan=nan_to_num)
+            for metric, metric_value in ret_metrics.items()
+        })
     return ret_metrics
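A hedged usage sketch of the new mean_fscore wrapper defined above, with toy inputs; it only illustrates the call signature and the returned keys documented in its docstring:

import numpy as np

from mmseg.core.evaluation import mean_fscore

preds = [np.random.randint(0, 19, size=(30, 30)) for _ in range(2)]
gts = [np.random.randint(0, 19, size=(30, 30)) for _ in range(2)]

# beta=2 weights recall more heavily than precision in the combined score.
ret = mean_fscore(preds, gts, num_classes=19, ignore_index=255, beta=2)
print(np.nanmean(ret['Fscore']),
      np.nanmean(ret['Precision']),
      np.nanmean(ret['Recall']))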


@@ -1,11 +1,12 @@
 import os
 import os.path as osp
+from collections import OrderedDict
 from functools import reduce

 import mmcv
 import numpy as np
 from mmcv.utils import print_log
-from terminaltables import AsciiTable
+from prettytable import PrettyTable
 from torch.utils.data import Dataset

 from mmseg.core import eval_metrics
@@ -312,8 +313,8 @@ class CustomDataset(Dataset):
         Args:
             results (list): Testing results of the dataset.
-            metric (str | list[str]): Metrics to be evaluated. 'mIoU' and
-                'mDice' are supported.
+            metric (str | list[str]): Metrics to be evaluated. 'mIoU',
+                'mDice' and 'mFscore' are supported.
             logger (logging.Logger | None | str): Logger used for printing
                 related information during evaluation. Default: None.
@@ -323,7 +324,7 @@ class CustomDataset(Dataset):
         if isinstance(metric, str):
             metric = [metric]
-        allowed_metrics = ['mIoU', 'mDice']
+        allowed_metrics = ['mIoU', 'mDice', 'mFscore']
         if not set(metric).issubset(set(allowed_metrics)):
             raise KeyError('metric {} is not supported'.format(metric))
         eval_results = {}
@@ -341,42 +342,57 @@ class CustomDataset(Dataset):
             metric,
             label_map=self.label_map,
             reduce_zero_label=self.reduce_zero_label)
-        class_table_data = [['Class'] + [m[1:] for m in metric] + ['Acc']]
+
         if self.CLASSES is None:
             class_names = tuple(range(num_classes))
         else:
             class_names = self.CLASSES
-        ret_metrics_round = [
-            np.round(ret_metric * 100, 2) for ret_metric in ret_metrics
-        ]
-        for i in range(num_classes):
-            class_table_data.append([class_names[i]] +
-                                    [m[i] for m in ret_metrics_round[2:]] +
-                                    [ret_metrics_round[1][i]])
-        summary_table_data = [['Scope'] +
-                              ['m' + head
-                               for head in class_table_data[0][1:]] + ['aAcc']]
-        ret_metrics_mean = [
-            np.round(np.nanmean(ret_metric) * 100, 2)
-            for ret_metric in ret_metrics
-        ]
-        summary_table_data.append(['global'] + ret_metrics_mean[2:] +
-                                  [ret_metrics_mean[1]] +
-                                  [ret_metrics_mean[0]])
-        print_log('per class results:', logger)
-        table = AsciiTable(class_table_data)
-        print_log('\n' + table.table, logger=logger)
-        print_log('Summary:', logger)
-        table = AsciiTable(summary_table_data)
-        print_log('\n' + table.table, logger=logger)
-        for i in range(1, len(summary_table_data[0])):
-            eval_results[summary_table_data[0]
-                         [i]] = summary_table_data[1][i] / 100.0
-        for idx, sub_metric in enumerate(class_table_data[0][1:], 1):
-            for item in class_table_data[1:]:
-                eval_results[str(sub_metric) + '.' +
-                             str(item[0])] = item[idx] / 100.0
+
+        # summary table
+        ret_metrics_summary = OrderedDict({
+            ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
+            for ret_metric, ret_metric_value in ret_metrics.items()
+        })
+
+        # each class table
+        ret_metrics.pop('aAcc', None)
+        ret_metrics_class = OrderedDict({
+            ret_metric: np.round(ret_metric_value * 100, 2)
+            for ret_metric, ret_metric_value in ret_metrics.items()
+        })
+        ret_metrics_class.update({'Class': class_names})
+        ret_metrics_class.move_to_end('Class', last=False)
+
+        # for logger
+        class_table_data = PrettyTable()
+        for key, val in ret_metrics_class.items():
+            class_table_data.add_column(key, val)
+
+        summary_table_data = PrettyTable()
+        for key, val in ret_metrics_summary.items():
+            if key == 'aAcc':
+                summary_table_data.add_column(key, [val])
+            else:
+                summary_table_data.add_column('m' + key, [val])
+
+        print_log('per class results:', logger)
+        print_log('\n' + class_table_data.get_string(), logger=logger)
+        print_log('Summary:', logger)
+        print_log('\n' + summary_table_data.get_string(), logger=logger)
+
+        # each metric dict
+        for key, value in ret_metrics_summary.items():
+            if key == 'aAcc':
+                eval_results[key] = value / 100.0
+            else:
+                eval_results['m' + key] = value / 100.0
+
+        ret_metrics_class.pop('Class', None)
+        for key, value in ret_metrics_class.items():
+            eval_results.update({
+                key + '.' + str(name): value[idx] / 100.0
+                for idx, name in enumerate(class_names)
+            })

         if mmcv.is_list_of(results, str):
             for file_name in results:
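For readers unfamiliar with prettytable, which replaces terminaltables here: tables are assembled column by column with add_column and rendered with get_string. A standalone sketch of the same pattern, with made-up class names and scores purely for illustration:

from prettytable import PrettyTable

# Made-up per-class numbers, only to show the add_column/get_string pattern.
class_table = PrettyTable()
class_table.add_column('Class', ['road', 'car'])
class_table.add_column('IoU', [97.1, 88.4])
class_table.add_column('Acc', [98.5, 93.2])
print(class_table.get_string())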


@@ -1,3 +1,3 @@
 matplotlib
 numpy
-terminaltables
+prettytable


@@ -8,6 +8,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmseg
-known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,seaborn,terminaltables,torch
+known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,prettytable,pytest,scipy,seaborn,torch
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY


@@ -159,7 +159,7 @@ def test_custom_dataset():
     for gt_seg_map in gt_seg_maps:
         h, w = gt_seg_map.shape
         pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
-    eval_results = train_dataset.evaluate(pseudo_results, metric='mIoU')
+    eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])
     assert isinstance(eval_results, dict)
     assert 'mIoU' in eval_results
     assert 'mAcc' in eval_results
@@ -193,13 +193,23 @@ def test_custom_dataset():
     assert 'mAcc' in eval_results
     assert 'aAcc' in eval_results

+    eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore')
+    assert isinstance(eval_results, dict)
+    assert 'mRecall' in eval_results
+    assert 'mPrecision' in eval_results
+    assert 'mFscore' in eval_results
+    assert 'aAcc' in eval_results
+
     eval_results = train_dataset.evaluate(
-        pseudo_results, metric=['mIoU', 'mDice'])
+        pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])
     assert isinstance(eval_results, dict)
     assert 'mIoU' in eval_results
     assert 'mDice' in eval_results
     assert 'mAcc' in eval_results
     assert 'aAcc' in eval_results
+    assert 'mFscore' in eval_results
+    assert 'mPrecision' in eval_results
+    assert 'mRecall' in eval_results


 @patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
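As the updated asserts above indicate, CustomDataset.evaluate now also reports the F-score family. A hedged sketch of reading those keys, reusing the train_dataset and pseudo_results objects set up in this test; summary values are divided by 100 inside evaluate, so they come back as fractions:

eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore')
print(eval_results['aAcc'], eval_results['mFscore'],
      eval_results['mPrecision'], eval_results['mRecall'])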


@@ -1,6 +1,8 @@
 import numpy as np

-from mmseg.core.evaluation import eval_metrics, mean_dice, mean_iou
+from mmseg.core.evaluation import (eval_metrics, mean_dice, mean_fscore,
+                                   mean_iou)
+from mmseg.core.evaluation.metrics import f_score


 def get_confusion_matrix(pred_label, label, num_classes, ignore_index):
@@ -58,6 +60,28 @@ def legacy_mean_dice(results, gt_seg_maps, num_classes, ignore_index):
     return all_acc, acc, dice


+# This func is deprecated since it's not memory efficient
+def legacy_mean_fscore(results,
+                       gt_seg_maps,
+                       num_classes,
+                       ignore_index,
+                       beta=1):
+    num_imgs = len(results)
+    assert len(gt_seg_maps) == num_imgs
+    total_mat = np.zeros((num_classes, num_classes), dtype=np.float)
+    for i in range(num_imgs):
+        mat = get_confusion_matrix(
+            results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index)
+        total_mat += mat
+    all_acc = np.diag(total_mat).sum() / total_mat.sum()
+    recall = np.diag(total_mat) / total_mat.sum(axis=1)
+    precision = np.diag(total_mat) / total_mat.sum(axis=0)
+    fv = np.vectorize(f_score)
+    fscore = fv(precision, recall, beta=beta)
+    return all_acc, recall, precision, fscore
+
+
 def test_metrics():
     pred_size = (10, 30, 30)
     num_classes = 19
@@ -69,63 +93,113 @@ def test_metrics():
     label[:, 2, 5:10] = ignore_index

     # Test the correctness of the implementation of mIoU calculation.
-    all_acc, acc, iou = eval_metrics(
+    ret_metrics = eval_metrics(
         results, label, num_classes, ignore_index, metrics='mIoU')
+    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
+        'IoU']
     all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
                                               ignore_index)
     assert all_acc == all_acc_l
     assert np.allclose(acc, acc_l)
     assert np.allclose(iou, iou_l)

     # Test the correctness of the implementation of mDice calculation.
-    all_acc, acc, dice = eval_metrics(
+    ret_metrics = eval_metrics(
         results, label, num_classes, ignore_index, metrics='mDice')
+    all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
+        'Dice']
     all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
                                                 ignore_index)
     assert all_acc == all_acc_l
     assert np.allclose(acc, acc_l)
     assert np.allclose(dice, dice_l)

+    # Test the correctness of the implementation of mFscore calculation.
+    ret_metrics = eval_metrics(
+        results, label, num_classes, ignore_index, metrics='mFscore')
+    all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
+        'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
+    all_acc_l, recall_l, precision_l, fscore_l = legacy_mean_fscore(
+        results, label, num_classes, ignore_index)
+    assert all_acc == all_acc_l
+    assert np.allclose(recall, recall_l)
+    assert np.allclose(precision, precision_l)
+    assert np.allclose(fscore, fscore_l)
+
     # Test the correctness of the implementation of joint calculation.
-    all_acc, acc, iou, dice = eval_metrics(
-        results, label, num_classes, ignore_index, metrics=['mIoU', 'mDice'])
+    ret_metrics = eval_metrics(
+        results,
+        label,
+        num_classes,
+        ignore_index,
+        metrics=['mIoU', 'mDice', 'mFscore'])
+    all_acc, acc, iou, dice, precision, recall, fscore = ret_metrics[
+        'aAcc'], ret_metrics['Acc'], ret_metrics['IoU'], ret_metrics[
+            'Dice'], ret_metrics['Precision'], ret_metrics[
+                'Recall'], ret_metrics['Fscore']
     assert all_acc == all_acc_l
     assert np.allclose(acc, acc_l)
     assert np.allclose(iou, iou_l)
     assert np.allclose(dice, dice_l)
+    assert np.allclose(precision, precision_l)
+    assert np.allclose(recall, recall_l)
+    assert np.allclose(fscore, fscore_l)

     # Test the correctness of calculation when arg: num_classes is larger
     # than the maximum value of input maps.
     results = np.random.randint(0, 5, size=pred_size)
     label = np.random.randint(0, 4, size=pred_size)
-    all_acc, acc, iou = eval_metrics(
+    ret_metrics = eval_metrics(
         results,
         label,
         num_classes,
         ignore_index=255,
         metrics='mIoU',
         nan_to_num=-1)
+    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
+        'IoU']
     assert acc[-1] == -1
     assert iou[-1] == -1

-    all_acc, acc, dice = eval_metrics(
+    ret_metrics = eval_metrics(
         results,
         label,
         num_classes,
         ignore_index=255,
         metrics='mDice',
         nan_to_num=-1)
+    all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
+        'Dice']
     assert acc[-1] == -1
     assert dice[-1] == -1

-    all_acc, acc, dice, iou = eval_metrics(
+    ret_metrics = eval_metrics(
         results,
         label,
         num_classes,
         ignore_index=255,
-        metrics=['mDice', 'mIoU'],
+        metrics='mFscore',
         nan_to_num=-1)
+    all_acc, precision, recall, fscore = ret_metrics['aAcc'], ret_metrics[
+        'Precision'], ret_metrics['Recall'], ret_metrics['Fscore']
+    assert precision[-1] == -1
+    assert recall[-1] == -1
+    assert fscore[-1] == -1
+
+    ret_metrics = eval_metrics(
+        results,
+        label,
+        num_classes,
+        ignore_index=255,
+        metrics=['mDice', 'mIoU', 'mFscore'],
+        nan_to_num=-1)
+    all_acc, acc, iou, dice, precision, recall, fscore = ret_metrics[
+        'aAcc'], ret_metrics['Acc'], ret_metrics['IoU'], ret_metrics[
+            'Dice'], ret_metrics['Precision'], ret_metrics[
+                'Recall'], ret_metrics['Fscore']
     assert acc[-1] == -1
     assert dice[-1] == -1
     assert iou[-1] == -1
+    assert precision[-1] == -1
+    assert recall[-1] == -1
+    assert fscore[-1] == -1

     # Test the bug which is caused by torch.histc.
     # torch.histc: https://pytorch.org/docs/stable/generated/torch.histc.html
@@ -134,8 +208,10 @@ def test_metrics():
     results = np.array([np.repeat(31, 59)])
     label = np.array([np.arange(59)])
     num_classes = 59
-    all_acc, acc, iou = eval_metrics(
+    ret_metrics = eval_metrics(
         results, label, num_classes, ignore_index=255, metrics='mIoU')
+    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
+        'IoU']
     assert not np.any(np.isnan(iou))
@@ -146,7 +222,9 @@ def test_mean_iou():
     results = np.random.randint(0, num_classes, size=pred_size)
     label = np.random.randint(0, num_classes, size=pred_size)
     label[:, 2, 5:10] = ignore_index
-    all_acc, acc, iou = mean_iou(results, label, num_classes, ignore_index)
+    ret_metrics = mean_iou(results, label, num_classes, ignore_index)
+    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
+        'IoU']
     all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
                                               ignore_index)
     assert all_acc == all_acc_l
@@ -155,10 +233,12 @@ def test_mean_iou():
     results = np.random.randint(0, 5, size=pred_size)
     label = np.random.randint(0, 4, size=pred_size)
-    all_acc, acc, iou = mean_iou(
+    ret_metrics = mean_iou(
         results, label, num_classes, ignore_index=255, nan_to_num=-1)
+    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
+        'IoU']
     assert acc[-1] == -1
     assert iou[-1] == -1


 def test_mean_dice():
@@ -168,19 +248,62 @@ def test_mean_dice():
     results = np.random.randint(0, num_classes, size=pred_size)
     label = np.random.randint(0, num_classes, size=pred_size)
     label[:, 2, 5:10] = ignore_index
-    all_acc, acc, iou = mean_dice(results, label, num_classes, ignore_index)
-    all_acc_l, acc_l, iou_l = legacy_mean_dice(results, label, num_classes,
-                                               ignore_index)
+    ret_metrics = mean_dice(results, label, num_classes, ignore_index)
+    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
+        'Dice']
+    all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
+                                                ignore_index)
     assert all_acc == all_acc_l
     assert np.allclose(acc, acc_l)
-    assert np.allclose(iou, iou_l)
+    assert np.allclose(iou, dice_l)

     results = np.random.randint(0, 5, size=pred_size)
     label = np.random.randint(0, 4, size=pred_size)
-    all_acc, acc, iou = mean_dice(
+    ret_metrics = mean_dice(
         results, label, num_classes, ignore_index=255, nan_to_num=-1)
+    all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
+        'Dice']
     assert acc[-1] == -1
-    assert iou[-1] == -1
+    assert dice[-1] == -1
+
+
+def test_mean_fscore():
+    pred_size = (10, 30, 30)
+    num_classes = 19
+    ignore_index = 255
+    results = np.random.randint(0, num_classes, size=pred_size)
+    label = np.random.randint(0, num_classes, size=pred_size)
+    label[:, 2, 5:10] = ignore_index
+    ret_metrics = mean_fscore(results, label, num_classes, ignore_index)
+    all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
+        'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
+    all_acc_l, recall_l, precision_l, fscore_l = legacy_mean_fscore(
+        results, label, num_classes, ignore_index)
+    assert all_acc == all_acc_l
+    assert np.allclose(recall, recall_l)
+    assert np.allclose(precision, precision_l)
+    assert np.allclose(fscore, fscore_l)
+
+    ret_metrics = mean_fscore(
+        results, label, num_classes, ignore_index, beta=2)
+    all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
+        'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
+    all_acc_l, recall_l, precision_l, fscore_l = legacy_mean_fscore(
+        results, label, num_classes, ignore_index, beta=2)
+    assert all_acc == all_acc_l
+    assert np.allclose(recall, recall_l)
+    assert np.allclose(precision, precision_l)
+    assert np.allclose(fscore, fscore_l)
+
+    results = np.random.randint(0, 5, size=pred_size)
+    label = np.random.randint(0, 4, size=pred_size)
+    ret_metrics = mean_fscore(
+        results, label, num_classes, ignore_index=255, nan_to_num=-1)
+    all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
+        'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
+    assert recall[-1] == -1
+    assert precision[-1] == -1
+    assert fscore[-1] == -1


 def test_filename_inputs():
@@ -211,13 +334,14 @@ def test_filename_inputs():
         result_files = save_arr(results, 'pred', False, temp_dir)
         label_files = save_arr(labels, 'label', True, temp_dir)

-        all_acc, acc, iou = eval_metrics(
+        ret_metrics = eval_metrics(
             result_files,
             label_files,
             num_classes,
             ignore_index,
             metrics='mIoU')
+        all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics[
+            'Acc'], ret_metrics['IoU']
         all_acc_l, acc_l, iou_l = legacy_mean_iou(results, labels, num_classes,
                                                   ignore_index)
         assert all_acc == all_acc_l