Mirror of https://github.com/open-mmlab/mmsegmentation.git (synced 2025-06-03 22:03:48 +08:00)
add metric mFscore (#509)
* add mFscore and refactor the metrics return value
* fix linting
* some docstring and name fix
This commit is contained in:
parent cf2cb542f7
commit e16e0e303b
@@ -1,8 +1,8 @@
 from .class_names import get_classes, get_palette
 from .eval_hooks import DistEvalHook, EvalHook
-from .metrics import eval_metrics, mean_dice, mean_iou
+from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou

 __all__ = [
-    'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'eval_metrics',
-    'get_classes', 'get_palette'
+    'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
+    'eval_metrics', 'get_classes', 'get_palette'
 ]
@@ -1,8 +1,27 @@
+from collections import OrderedDict
+
 import mmcv
 import numpy as np
 import torch


+def f_score(precision, recall, beta=1):
+    """Calculate the f-score value.
+
+    Args:
+        precision (float | torch.Tensor): The precision value.
+        recall (float | torch.Tensor): The recall value.
+        beta (int): Determines the weight of recall in the combined score.
+            Default: 1.
+
+    Returns:
+        [torch.Tensor]: The f-score value.
+    """
+    score = (1 + beta**2) * (precision * recall) / (
+        (beta**2 * precision) + recall)
+    return score
+
+
 def intersect_and_union(pred_label,
                         label,
                         num_classes,
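For reference, a quick sanity check of the new helper (a standalone sketch, not part of the commit; the import path matches the one used by the tests further down):

import torch
# f_score is importable from mmseg.core.evaluation.metrics after this change.
from mmseg.core.evaluation.metrics import f_score

precision = torch.tensor([0.5, 0.8])
recall = torch.tensor([0.5, 0.4])
print(f_score(precision, recall))          # tensor([0.5000, 0.5333]): harmonic mean at beta=1
print(f_score(precision, recall, beta=2))  # beta > 1 weighs recall more heavily than precision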
@@ -133,11 +152,12 @@ def mean_iou(results,
         reduce_zero_label (bool): Wether ignore zero label. Default: False.

     Returns:
-        float: Overall accuracy on all images.
-        ndarray: Per category accuracy, shape (num_classes, ).
-        ndarray: Per category IoU, shape (num_classes, ).
+        dict[str, float | ndarray]:
+            <aAcc> float: Overall accuracy on all images.
+            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
+            <IoU> ndarray: Per category IoU, shape (num_classes, ).
     """
-    all_acc, acc, iou = eval_metrics(
+    iou_result = eval_metrics(
         results=results,
         gt_seg_maps=gt_seg_maps,
         num_classes=num_classes,
@@ -146,7 +166,7 @@ def mean_iou(results,
         nan_to_num=nan_to_num,
         label_map=label_map,
         reduce_zero_label=reduce_zero_label)
-    return all_acc, acc, iou
+    return iou_result


 def mean_dice(results,
@@ -171,12 +191,13 @@ def mean_dice(results,
         reduce_zero_label (bool): Wether ignore zero label. Default: False.

     Returns:
-        float: Overall accuracy on all images.
-        ndarray: Per category accuracy, shape (num_classes, ).
-        ndarray: Per category dice, shape (num_classes, ).
+        dict[str, float | ndarray]: Default metrics.
+            <aAcc> float: Overall accuracy on all images.
+            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
+            <Dice> ndarray: Per category dice, shape (num_classes, ).
     """

-    all_acc, acc, dice = eval_metrics(
+    dice_result = eval_metrics(
         results=results,
         gt_seg_maps=gt_seg_maps,
         num_classes=num_classes,
@@ -185,7 +206,52 @@ def mean_dice(results,
         nan_to_num=nan_to_num,
         label_map=label_map,
         reduce_zero_label=reduce_zero_label)
-    return all_acc, acc, dice
+    return dice_result
+
+
+def mean_fscore(results,
+                gt_seg_maps,
+                num_classes,
+                ignore_index,
+                nan_to_num=None,
+                label_map=dict(),
+                reduce_zero_label=False,
+                beta=1):
+    """Calculate Mean F-Score (mFscore)
+
+    Args:
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+            segmentation maps or list of label filenames.
+        num_classes (int): Number of categories.
+        ignore_index (int): Index that will be ignored in evaluation.
+        nan_to_num (int, optional): If specified, NaN values will be replaced
+            by the numbers defined by the user. Default: None.
+        label_map (dict): Mapping old labels to new labels. Default: dict().
+        reduce_zero_label (bool): Whether ignore zero label. Default: False.
+        beta (int): Determines the weight of recall in the combined score.
+            Default: 1.
+
+    Returns:
+        dict[str, float | ndarray]: Default metrics.
+            <aAcc> float: Overall accuracy on all images.
+            <Fscore> ndarray: Per category f-score, shape (num_classes, ).
+            <Precision> ndarray: Per category precision, shape (num_classes, ).
+            <Recall> ndarray: Per category recall, shape (num_classes, ).
+    """
+    fscore_result = eval_metrics(
+        results=results,
+        gt_seg_maps=gt_seg_maps,
+        num_classes=num_classes,
+        ignore_index=ignore_index,
+        metrics=['mFscore'],
+        nan_to_num=nan_to_num,
+        label_map=label_map,
+        reduce_zero_label=reduce_zero_label,
+        beta=beta)
+    return fscore_result


 def eval_metrics(results,
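With this change mean_iou, mean_dice and the new mean_fscore return an OrderedDict keyed by metric name instead of a tuple. A minimal usage sketch with random inputs (names and shapes are illustrative only):

import numpy as np
from mmseg.core.evaluation import mean_fscore

num_classes = 19
results = [np.random.randint(0, num_classes, size=(32, 32)) for _ in range(2)]
gt_seg_maps = [np.random.randint(0, num_classes, size=(32, 32)) for _ in range(2)]

ret = mean_fscore(results, gt_seg_maps, num_classes, ignore_index=255)
print(ret['aAcc'])       # overall accuracy (scalar)
print(ret['Fscore'])     # per-class F-score, shape (num_classes,)
print(ret['Precision'])  # per-class precision
print(ret['Recall'])     # per-class recall
# Classes that never appear in the predictions come out as NaN; pass
# nan_to_num=-1 to replace them, as the tests below do.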
@@ -195,7 +261,8 @@ def eval_metrics(results,
                  metrics=['mIoU'],
                  nan_to_num=None,
                  label_map=dict(),
-                 reduce_zero_label=False):
+                 reduce_zero_label=False,
+                 beta=1):
     """Calculate evaluation metrics
     Args:
         results (list[ndarray] | list[str]): List of prediction segmentation
@@ -210,13 +277,13 @@ def eval_metrics(results,
         label_map (dict): Mapping old labels to new labels. Default: dict().
         reduce_zero_label (bool): Wether ignore zero label. Default: False.
     Returns:
         float: Overall accuracy on all images.
         ndarray: Per category accuracy, shape (num_classes, ).
         ndarray: Per category evaluation metrics, shape (num_classes, ).
     """
     if isinstance(metrics, str):
         metrics = [metrics]
-    allowed_metrics = ['mIoU', 'mDice']
+    allowed_metrics = ['mIoU', 'mDice', 'mFscore']
     if not set(metrics).issubset(set(allowed_metrics)):
         raise KeyError('metrics {} is not supported'.format(metrics))

@@ -225,19 +292,35 @@ def eval_metrics(results,
         results, gt_seg_maps, num_classes, ignore_index, label_map,
         reduce_zero_label)
     all_acc = total_area_intersect.sum() / total_area_label.sum()
-    acc = total_area_intersect / total_area_label
-    ret_metrics = [all_acc, acc]
+    ret_metrics = OrderedDict({'aAcc': all_acc})
+
     for metric in metrics:
         if metric == 'mIoU':
             iou = total_area_intersect / total_area_union
-            ret_metrics.append(iou)
+            acc = total_area_intersect / total_area_label
+            ret_metrics['IoU'] = iou
+            ret_metrics['Acc'] = acc
         elif metric == 'mDice':
             dice = 2 * total_area_intersect / (
                 total_area_pred_label + total_area_label)
-            ret_metrics.append(dice)
-    ret_metrics = [metric.numpy() for metric in ret_metrics]
+            acc = total_area_intersect / total_area_label
+            ret_metrics['Dice'] = dice
+            ret_metrics['Acc'] = acc
+        elif metric == 'mFscore':
+            precision = total_area_intersect / total_area_pred_label
+            recall = total_area_intersect / total_area_label
+            f_value = torch.tensor(
+                [f_score(x[0], x[1], beta) for x in zip(precision, recall)])
+            ret_metrics['Fscore'] = f_value
+            ret_metrics['Precision'] = precision
+            ret_metrics['Recall'] = recall
+
+    ret_metrics = {
+        metric: value.numpy()
+        for metric, value in ret_metrics.items()
+    }
     if nan_to_num is not None:
-        ret_metrics = [
-            np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics
-        ]
+        ret_metrics = OrderedDict({
+            metric: np.nan_to_num(metric_value, nan=nan_to_num)
+            for metric, metric_value in ret_metrics.items()
+        })
     return ret_metrics
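Since eval_metrics now fills one dictionary per requested metric, asking for several metrics at once simply merges their keys. A hedged sketch of what a caller would see (random inputs, key names taken from the branches above):

import numpy as np
from mmseg.core.evaluation import eval_metrics

results = [np.random.randint(0, 4, size=(16, 16))]
gt_seg_maps = [np.random.randint(0, 4, size=(16, 16))]
ret = eval_metrics(results, gt_seg_maps, num_classes=4, ignore_index=255,
                   metrics=['mIoU', 'mDice', 'mFscore'])
print(sorted(ret))  # ['Acc', 'Dice', 'Fscore', 'IoU', 'Precision', 'Recall', 'aAcc']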
@@ -1,11 +1,12 @@
 import os
 import os.path as osp
+from collections import OrderedDict
 from functools import reduce

 import mmcv
 import numpy as np
 from mmcv.utils import print_log
-from terminaltables import AsciiTable
+from prettytable import PrettyTable
 from torch.utils.data import Dataset

 from mmseg.core import eval_metrics
@@ -312,8 +313,8 @@ class CustomDataset(Dataset):

         Args:
             results (list): Testing results of the dataset.
-            metric (str | list[str]): Metrics to be evaluated. 'mIoU' and
-                'mDice' are supported.
+            metric (str | list[str]): Metrics to be evaluated. 'mIoU',
+                'mDice' and 'mFscore' are supported.
             logger (logging.Logger | None | str): Logger used for printing
                 related information during evaluation. Default: None.

@@ -323,7 +324,7 @@ class CustomDataset(Dataset):

         if isinstance(metric, str):
             metric = [metric]
-        allowed_metrics = ['mIoU', 'mDice']
+        allowed_metrics = ['mIoU', 'mDice', 'mFscore']
         if not set(metric).issubset(set(allowed_metrics)):
             raise KeyError('metric {} is not supported'.format(metric))
         eval_results = {}
@@ -341,42 +342,57 @@ class CustomDataset(Dataset):
             metric,
             label_map=self.label_map,
             reduce_zero_label=self.reduce_zero_label)
-        class_table_data = [['Class'] + [m[1:] for m in metric] + ['Acc']]
+
         if self.CLASSES is None:
             class_names = tuple(range(num_classes))
         else:
             class_names = self.CLASSES
-        ret_metrics_round = [
-            np.round(ret_metric * 100, 2) for ret_metric in ret_metrics
-        ]
-        for i in range(num_classes):
-            class_table_data.append([class_names[i]] +
-                                    [m[i] for m in ret_metrics_round[2:]] +
-                                    [ret_metrics_round[1][i]])
-        summary_table_data = [['Scope'] +
-                              ['m' + head
-                               for head in class_table_data[0][1:]] + ['aAcc']]
-        ret_metrics_mean = [
-            np.round(np.nanmean(ret_metric) * 100, 2)
-            for ret_metric in ret_metrics
-        ]
-        summary_table_data.append(['global'] + ret_metrics_mean[2:] +
-                                  [ret_metrics_mean[1]] +
-                                  [ret_metrics_mean[0]])
-        print_log('per class results:', logger)
-        table = AsciiTable(class_table_data)
-        print_log('\n' + table.table, logger=logger)
-        print_log('Summary:', logger)
-        table = AsciiTable(summary_table_data)
-        print_log('\n' + table.table, logger=logger)
-
-        for i in range(1, len(summary_table_data[0])):
-            eval_results[summary_table_data[0]
-                         [i]] = summary_table_data[1][i] / 100.0
-        for idx, sub_metric in enumerate(class_table_data[0][1:], 1):
-            for item in class_table_data[1:]:
-                eval_results[str(sub_metric) + '.' +
-                             str(item[0])] = item[idx] / 100.0
+
+        # summary table
+        ret_metrics_summary = OrderedDict({
+            ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
+            for ret_metric, ret_metric_value in ret_metrics.items()
+        })
+
+        # each class table
+        ret_metrics.pop('aAcc', None)
+        ret_metrics_class = OrderedDict({
+            ret_metric: np.round(ret_metric_value * 100, 2)
+            for ret_metric, ret_metric_value in ret_metrics.items()
+        })
+        ret_metrics_class.update({'Class': class_names})
+        ret_metrics_class.move_to_end('Class', last=False)
+
+        # for logger
+        class_table_data = PrettyTable()
+        for key, val in ret_metrics_class.items():
+            class_table_data.add_column(key, val)
+
+        summary_table_data = PrettyTable()
+        for key, val in ret_metrics_summary.items():
+            if key == 'aAcc':
+                summary_table_data.add_column(key, [val])
+            else:
+                summary_table_data.add_column('m' + key, [val])
+
+        print_log('per class results:', logger)
+        print_log('\n' + class_table_data.get_string(), logger=logger)
+        print_log('Summary:', logger)
+        print_log('\n' + summary_table_data.get_string(), logger=logger)
+
+        # each metric dict
+        for key, value in ret_metrics_summary.items():
+            if key == 'aAcc':
+                eval_results[key] = value / 100.0
+            else:
+                eval_results['m' + key] = value / 100.0
+
+        ret_metrics_class.pop('Class', None)
+        for key, value in ret_metrics_class.items():
+            eval_results.update({
+                key + '.' + str(name): value[idx] / 100.0
+                for idx, name in enumerate(class_names)
+            })

         if mmcv.is_list_of(results, str):
             for file_name in results:
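The logging switches from terminaltables.AsciiTable, which is built row by row, to prettytable.PrettyTable, which is filled column by column from the metric dictionaries. A small standalone sketch of the same pattern with made-up numbers:

from collections import OrderedDict
from prettytable import PrettyTable

# Per-class metrics already rounded to percentages, as in evaluate() above.
ret_metrics_class = OrderedDict({
    'Class': ['road', 'car'],
    'IoU': [95.1, 80.3],
    'Acc': [97.0, 88.2],
})

table = PrettyTable()
for key, val in ret_metrics_class.items():
    table.add_column(key, val)
print(table.get_string())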
@@ -1,3 +1,3 @@
 matplotlib
 numpy
-terminaltables
+prettytable
@@ -8,6 +8,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmseg
-known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,seaborn,terminaltables,torch
+known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,prettytable,pytest,scipy,seaborn,torch
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
@@ -159,7 +159,7 @@ def test_custom_dataset():
     for gt_seg_map in gt_seg_maps:
         h, w = gt_seg_map.shape
         pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
-    eval_results = train_dataset.evaluate(pseudo_results, metric='mIoU')
+    eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])
     assert isinstance(eval_results, dict)
     assert 'mIoU' in eval_results
     assert 'mAcc' in eval_results
@@ -193,13 +193,23 @@ def test_custom_dataset():
     assert 'mAcc' in eval_results
     assert 'aAcc' in eval_results

+    eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore')
+    assert isinstance(eval_results, dict)
+    assert 'mRecall' in eval_results
+    assert 'mPrecision' in eval_results
+    assert 'mFscore' in eval_results
+    assert 'aAcc' in eval_results
+
     eval_results = train_dataset.evaluate(
-        pseudo_results, metric=['mIoU', 'mDice'])
+        pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])
     assert isinstance(eval_results, dict)
     assert 'mIoU' in eval_results
     assert 'mDice' in eval_results
     assert 'mAcc' in eval_results
     assert 'aAcc' in eval_results
+    assert 'mFscore' in eval_results
+    assert 'mPrecision' in eval_results
+    assert 'mRecall' in eval_results


 @patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
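As the updated tests show, CustomDataset.evaluate now accepts 'mFscore' and flattens the tables into scalar entries: summary values such as 'mFscore', 'mPrecision', 'mRecall' and 'aAcc', plus per-class entries named '<Metric>.<class_name>'. A hedged usage sketch (dataset and results are placeholder objects, not defined here):

# Hypothetical objects: `dataset` is a CustomDataset instance and `results`
# is a list of per-image prediction maps with the same length as the dataset.
eval_results = dataset.evaluate(results, metric=['mIoU', 'mFscore'])
print(eval_results['mIoU'], eval_results['mFscore'], eval_results['aAcc'])
# Per-class entries follow the "<Metric>.<class_name>" pattern, e.g.
# eval_results['IoU.road'] or eval_results['Fscore.car'].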
@@ -1,6 +1,8 @@
 import numpy as np

-from mmseg.core.evaluation import eval_metrics, mean_dice, mean_iou
+from mmseg.core.evaluation import (eval_metrics, mean_dice, mean_fscore,
+                                   mean_iou)
+from mmseg.core.evaluation.metrics import f_score


 def get_confusion_matrix(pred_label, label, num_classes, ignore_index):
@@ -58,6 +60,28 @@ def legacy_mean_dice(results, gt_seg_maps, num_classes, ignore_index):
     return all_acc, acc, dice


+# This func is deprecated since it's not memory efficient
+def legacy_mean_fscore(results,
+                       gt_seg_maps,
+                       num_classes,
+                       ignore_index,
+                       beta=1):
+    num_imgs = len(results)
+    assert len(gt_seg_maps) == num_imgs
+    total_mat = np.zeros((num_classes, num_classes), dtype=np.float)
+    for i in range(num_imgs):
+        mat = get_confusion_matrix(
+            results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index)
+        total_mat += mat
+    all_acc = np.diag(total_mat).sum() / total_mat.sum()
+    recall = np.diag(total_mat) / total_mat.sum(axis=1)
+    precision = np.diag(total_mat) / total_mat.sum(axis=0)
+    fv = np.vectorize(f_score)
+    fscore = fv(precision, recall, beta=beta)
+
+    return all_acc, recall, precision, fscore
+
+
 def test_metrics():
     pred_size = (10, 30, 30)
     num_classes = 19
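The reference implementation added here reads precision and recall off a dense confusion matrix (rows are ground-truth classes, columns are predictions, as built by get_confusion_matrix above). A tiny standalone illustration of that relationship:

import numpy as np

# 2 classes; rows = ground truth, columns = prediction
conf = np.array([[40., 10.],
                 [5., 45.]])
recall = np.diag(conf) / conf.sum(axis=1)     # TP / ground-truth pixels per class
precision = np.diag(conf) / conf.sum(axis=0)  # TP / predicted pixels per class
print(recall)     # [0.8 0.9]
print(precision)  # [0.888... 0.818...]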
@@ -69,63 +93,113 @@ def test_metrics():
     label[:, 2, 5:10] = ignore_index

     # Test the correctness of the implementation of mIoU calculation.
-    all_acc, acc, iou = eval_metrics(
+    ret_metrics = eval_metrics(
         results, label, num_classes, ignore_index, metrics='mIoU')
+    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
+        'IoU']
     all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
                                               ignore_index)
     assert all_acc == all_acc_l
     assert np.allclose(acc, acc_l)
     assert np.allclose(iou, iou_l)
     # Test the correctness of the implementation of mDice calculation.
-    all_acc, acc, dice = eval_metrics(
+    ret_metrics = eval_metrics(
         results, label, num_classes, ignore_index, metrics='mDice')
+    all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
+        'Dice']
     all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
                                                 ignore_index)
     assert all_acc == all_acc_l
     assert np.allclose(acc, acc_l)
     assert np.allclose(dice, dice_l)
+    # Test the correctness of the implementation of mFscore calculation.
+    ret_metrics = eval_metrics(
+        results, label, num_classes, ignore_index, metrics='mFscore')
+    all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
+        'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
+    all_acc_l, recall_l, precision_l, fscore_l = legacy_mean_fscore(
+        results, label, num_classes, ignore_index)
+    assert all_acc == all_acc_l
+    assert np.allclose(recall, recall_l)
+    assert np.allclose(precision, precision_l)
+    assert np.allclose(fscore, fscore_l)
     # Test the correctness of the implementation of joint calculation.
-    all_acc, acc, iou, dice = eval_metrics(
-        results, label, num_classes, ignore_index, metrics=['mIoU', 'mDice'])
+    ret_metrics = eval_metrics(
+        results,
+        label,
+        num_classes,
+        ignore_index,
+        metrics=['mIoU', 'mDice', 'mFscore'])
+    all_acc, acc, iou, dice, precision, recall, fscore = ret_metrics[
+        'aAcc'], ret_metrics['Acc'], ret_metrics['IoU'], ret_metrics[
+            'Dice'], ret_metrics['Precision'], ret_metrics[
+                'Recall'], ret_metrics['Fscore']
     assert all_acc == all_acc_l
     assert np.allclose(acc, acc_l)
     assert np.allclose(iou, iou_l)
     assert np.allclose(dice, dice_l)
+    assert np.allclose(precision, precision_l)
+    assert np.allclose(recall, recall_l)
+    assert np.allclose(fscore, fscore_l)

     # Test the correctness of calculation when arg: num_classes is larger
     # than the maximum value of input maps.
     results = np.random.randint(0, 5, size=pred_size)
     label = np.random.randint(0, 4, size=pred_size)
-    all_acc, acc, iou = eval_metrics(
+    ret_metrics = eval_metrics(
         results,
         label,
         num_classes,
         ignore_index=255,
         metrics='mIoU',
         nan_to_num=-1)
+    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
+        'IoU']
     assert acc[-1] == -1
     assert iou[-1] == -1

-    all_acc, acc, dice = eval_metrics(
+    ret_metrics = eval_metrics(
         results,
         label,
         num_classes,
         ignore_index=255,
         metrics='mDice',
         nan_to_num=-1)
+    all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
+        'Dice']
     assert acc[-1] == -1
     assert dice[-1] == -1

-    all_acc, acc, dice, iou = eval_metrics(
+    ret_metrics = eval_metrics(
         results,
         label,
         num_classes,
         ignore_index=255,
-        metrics=['mDice', 'mIoU'],
+        metrics='mFscore',
         nan_to_num=-1)
+    all_acc, precision, recall, fscore = ret_metrics['aAcc'], ret_metrics[
+        'Precision'], ret_metrics['Recall'], ret_metrics['Fscore']
+    assert precision[-1] == -1
+    assert recall[-1] == -1
+    assert fscore[-1] == -1
+
+    ret_metrics = eval_metrics(
+        results,
+        label,
+        num_classes,
+        ignore_index=255,
+        metrics=['mDice', 'mIoU', 'mFscore'],
+        nan_to_num=-1)
+    all_acc, acc, iou, dice, precision, recall, fscore = ret_metrics[
+        'aAcc'], ret_metrics['Acc'], ret_metrics['IoU'], ret_metrics[
+            'Dice'], ret_metrics['Precision'], ret_metrics[
+                'Recall'], ret_metrics['Fscore']
     assert acc[-1] == -1
     assert dice[-1] == -1
     assert iou[-1] == -1
+    assert precision[-1] == -1
+    assert recall[-1] == -1
+    assert fscore[-1] == -1

     # Test the bug which is caused by torch.histc.
     # torch.histc: https://pytorch.org/docs/stable/generated/torch.histc.html
@@ -134,8 +208,10 @@ def test_metrics():
     results = np.array([np.repeat(31, 59)])
     label = np.array([np.arange(59)])
     num_classes = 59
-    all_acc, acc, iou = eval_metrics(
+    ret_metrics = eval_metrics(
         results, label, num_classes, ignore_index=255, metrics='mIoU')
+    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
+        'IoU']
     assert not np.any(np.isnan(iou))

@@ -146,7 +222,9 @@ def test_mean_iou():
     results = np.random.randint(0, num_classes, size=pred_size)
     label = np.random.randint(0, num_classes, size=pred_size)
     label[:, 2, 5:10] = ignore_index
-    all_acc, acc, iou = mean_iou(results, label, num_classes, ignore_index)
+    ret_metrics = mean_iou(results, label, num_classes, ignore_index)
+    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
+        'IoU']
     all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
                                               ignore_index)
     assert all_acc == all_acc_l
@@ -155,10 +233,12 @@ def test_mean_iou():

     results = np.random.randint(0, 5, size=pred_size)
     label = np.random.randint(0, 4, size=pred_size)
-    all_acc, acc, iou = mean_iou(
+    ret_metrics = mean_iou(
         results, label, num_classes, ignore_index=255, nan_to_num=-1)
+    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
+        'IoU']
     assert acc[-1] == -1
     assert iou[-1] == -1


 def test_mean_dice():
@@ -168,19 +248,62 @@ def test_mean_dice():
     results = np.random.randint(0, num_classes, size=pred_size)
     label = np.random.randint(0, num_classes, size=pred_size)
     label[:, 2, 5:10] = ignore_index
-    all_acc, acc, iou = mean_dice(results, label, num_classes, ignore_index)
-    all_acc_l, acc_l, iou_l = legacy_mean_dice(results, label, num_classes,
-                                               ignore_index)
+    ret_metrics = mean_dice(results, label, num_classes, ignore_index)
+    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
+        'Dice']
+    all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
+                                                ignore_index)
     assert all_acc == all_acc_l
     assert np.allclose(acc, acc_l)
-    assert np.allclose(iou, iou_l)
+    assert np.allclose(iou, dice_l)

     results = np.random.randint(0, 5, size=pred_size)
     label = np.random.randint(0, 4, size=pred_size)
-    all_acc, acc, iou = mean_dice(
+    ret_metrics = mean_dice(
         results, label, num_classes, ignore_index=255, nan_to_num=-1)
+    all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
+        'Dice']
     assert acc[-1] == -1
-    assert iou[-1] == -1
+    assert dice[-1] == -1
+
+
+def test_mean_fscore():
+    pred_size = (10, 30, 30)
+    num_classes = 19
+    ignore_index = 255
+    results = np.random.randint(0, num_classes, size=pred_size)
+    label = np.random.randint(0, num_classes, size=pred_size)
+    label[:, 2, 5:10] = ignore_index
+    ret_metrics = mean_fscore(results, label, num_classes, ignore_index)
+    all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
+        'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
+    all_acc_l, recall_l, precision_l, fscore_l = legacy_mean_fscore(
+        results, label, num_classes, ignore_index)
+    assert all_acc == all_acc_l
+    assert np.allclose(recall, recall_l)
+    assert np.allclose(precision, precision_l)
+    assert np.allclose(fscore, fscore_l)
+
+    ret_metrics = mean_fscore(
+        results, label, num_classes, ignore_index, beta=2)
+    all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
+        'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
+    all_acc_l, recall_l, precision_l, fscore_l = legacy_mean_fscore(
+        results, label, num_classes, ignore_index, beta=2)
+    assert all_acc == all_acc_l
+    assert np.allclose(recall, recall_l)
+    assert np.allclose(precision, precision_l)
+    assert np.allclose(fscore, fscore_l)
+
+    results = np.random.randint(0, 5, size=pred_size)
+    label = np.random.randint(0, 4, size=pred_size)
+    ret_metrics = mean_fscore(
+        results, label, num_classes, ignore_index=255, nan_to_num=-1)
+    all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
+        'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
+    assert recall[-1] == -1
+    assert precision[-1] == -1
+    assert fscore[-1] == -1


 def test_filename_inputs():
@@ -211,13 +334,14 @@ def test_filename_inputs():
     result_files = save_arr(results, 'pred', False, temp_dir)
     label_files = save_arr(labels, 'label', True, temp_dir)

-    all_acc, acc, iou = eval_metrics(
+    ret_metrics = eval_metrics(
         result_files,
         label_files,
         num_classes,
         ignore_index,
         metrics='mIoU')
+    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics[
+        'Acc'], ret_metrics['IoU']
     all_acc_l, acc_l, iou_l = legacy_mean_iou(results, labels, num_classes,
                                               ignore_index)
     assert all_acc == all_acc_l
|
Loading…
x
Reference in New Issue
Block a user