mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
add dice evaluation metric (#225)
* add dice evaluation metric
* support 2 metrics
* fix docstring
* use np.round once for all
This commit is contained in:
parent 90e8e38e84
commit 993be2523b
mmseg/core/evaluation/__init__.py
@@ -1,7 +1,8 @@
 from .class_names import get_classes, get_palette
 from .eval_hooks import DistEvalHook, EvalHook
-from .mean_iou import mean_iou
+from .metrics import eval_metrics, mean_dice, mean_iou

 __all__ = [
-    'EvalHook', 'DistEvalHook', 'mean_iou', 'get_classes', 'get_palette'
+    'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'eval_metrics',
+    'get_classes', 'get_palette'
 ]
mmseg/core/evaluation/mean_iou.py (deleted)
@@ -1,74 +0,0 @@
import numpy as np


def intersect_and_union(pred_label, label, num_classes, ignore_index):
    """Calculate intersection and Union.

    Args:
        pred_label (ndarray): Prediction segmentation map
        label (ndarray): Ground truth segmentation map
        num_classes (int): Number of categories
        ignore_index (int): Index that will be ignored in evaluation.

    Returns:
        ndarray: The intersection of prediction and ground truth histogram
            on all classes
        ndarray: The union of prediction and ground truth histogram on all
            classes
        ndarray: The prediction histogram on all classes.
        ndarray: The ground truth histogram on all classes.
    """

    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    intersect = pred_label[pred_label == label]
    area_intersect, _ = np.histogram(
        intersect, bins=np.arange(num_classes + 1))
    area_pred_label, _ = np.histogram(
        pred_label, bins=np.arange(num_classes + 1))
    area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))
    area_union = area_pred_label + area_label - area_intersect

    return area_intersect, area_union, area_pred_label, area_label


def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None):
    """Calculate Intersection and Union (IoU)

    Args:
        results (list[ndarray]): List of prediction segmentation maps
        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
        num_classes (int): Number of categories
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, )
        ndarray: Per category IoU, shape (num_classes, )
    """

    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    total_area_intersect = np.zeros((num_classes, ), dtype=np.float)
    total_area_union = np.zeros((num_classes, ), dtype=np.float)
    total_area_pred_label = np.zeros((num_classes, ), dtype=np.float)
    total_area_label = np.zeros((num_classes, ), dtype=np.float)
    for i in range(num_imgs):
        area_intersect, area_union, area_pred_label, area_label = \
            intersect_and_union(results[i], gt_seg_maps[i], num_classes,
                                ignore_index=ignore_index)
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label
    all_acc = total_area_intersect.sum() / total_area_label.sum()
    acc = total_area_intersect / total_area_label
    iou = total_area_intersect / total_area_union
    if nan_to_num is not None:
        return all_acc, np.nan_to_num(acc, nan=nan_to_num), \
            np.nan_to_num(iou, nan=nan_to_num)
    return all_acc, acc, iou
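A note on the histogram trick used in the file above (and kept in the new metrics.py): for non-negative integer labels, the per-class "area" histograms are plain label counts. A minimal sketch, assuming only NumPy and a small hand-made array:

import numpy as np

# np.histogram with bin edges 0..K gives the same per-class counts as
# np.bincount with minlength=K, because each integer label v falls into bin v.
x = np.array([0, 1, 1, 3, 3, 3])
K = 5
hist, _ = np.histogram(x, bins=np.arange(K + 1))
assert np.array_equal(hist, np.bincount(x, minlength=K))  # [1, 2, 0, 3, 0]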
mmseg/core/evaluation/metrics.py (new file, 176 lines)
@@ -0,0 +1,176 @@
import numpy as np


def intersect_and_union(pred_label, label, num_classes, ignore_index):
    """Calculate intersection and Union.

    Args:
        pred_label (ndarray): Prediction segmentation map
        label (ndarray): Ground truth segmentation map
        num_classes (int): Number of categories
        ignore_index (int): Index that will be ignored in evaluation.

    Returns:
        ndarray: The intersection of prediction and ground truth histogram
            on all classes
        ndarray: The union of prediction and ground truth histogram on all
            classes
        ndarray: The prediction histogram on all classes.
        ndarray: The ground truth histogram on all classes.
    """

    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    intersect = pred_label[pred_label == label]
    area_intersect, _ = np.histogram(
        intersect, bins=np.arange(num_classes + 1))
    area_pred_label, _ = np.histogram(
        pred_label, bins=np.arange(num_classes + 1))
    area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))
    area_union = area_pred_label + area_label - area_intersect

    return area_intersect, area_union, area_pred_label, area_label


def total_intersect_and_union(results, gt_seg_maps, num_classes, ignore_index):
    """Calculate Total Intersection and Union.

    Args:
        results (list[ndarray]): List of prediction segmentation maps
        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
        num_classes (int): Number of categories
        ignore_index (int): Index that will be ignored in evaluation.

    Returns:
        ndarray: The intersection of prediction and ground truth histogram
            on all classes
        ndarray: The union of prediction and ground truth histogram on all
            classes
        ndarray: The prediction histogram on all classes.
        ndarray: The ground truth histogram on all classes.
    """

    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    total_area_intersect = np.zeros((num_classes, ), dtype=np.float)
    total_area_union = np.zeros((num_classes, ), dtype=np.float)
    total_area_pred_label = np.zeros((num_classes, ), dtype=np.float)
    total_area_label = np.zeros((num_classes, ), dtype=np.float)
    for i in range(num_imgs):
        area_intersect, area_union, area_pred_label, area_label = \
            intersect_and_union(results[i], gt_seg_maps[i], num_classes,
                                ignore_index=ignore_index)
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label
    return total_area_intersect, total_area_union, \
        total_area_pred_label, total_area_label


def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None):
    """Calculate Mean Intersection and Union (mIoU)

    Args:
        results (list[ndarray]): List of prediction segmentation maps
        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
        num_classes (int): Number of categories
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, )
        ndarray: Per category IoU, shape (num_classes, )
    """

    all_acc, acc, iou = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mIoU'],
        nan_to_num=nan_to_num)
    return all_acc, acc, iou


def mean_dice(results,
              gt_seg_maps,
              num_classes,
              ignore_index,
              nan_to_num=None):
    """Calculate Mean Dice (mDice)

    Args:
        results (list[ndarray]): List of prediction segmentation maps
        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
        num_classes (int): Number of categories
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, )
        ndarray: Per category dice, shape (num_classes, )
    """

    all_acc, acc, dice = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mDice'],
        nan_to_num=nan_to_num)
    return all_acc, acc, dice


def eval_metrics(results,
                 gt_seg_maps,
                 num_classes,
                 ignore_index,
                 metrics=['mIoU'],
                 nan_to_num=None):
    """Calculate evaluation metrics

    Args:
        results (list[ndarray]): List of prediction segmentation maps
        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
        num_classes (int): Number of categories
        ignore_index (int): Index that will be ignored in evaluation.
        metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, )
        ndarray: Per category evaluation metrics, shape (num_classes, )
    """

    if isinstance(metrics, str):
        metrics = [metrics]
    allowed_metrics = ['mIoU', 'mDice']
    if not set(metrics).issubset(set(allowed_metrics)):
        raise KeyError('metrics {} is not supported'.format(metrics))
    total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label = total_intersect_and_union(results, gt_seg_maps,
                                                     num_classes,
                                                     ignore_index=ignore_index)
    all_acc = total_area_intersect.sum() / total_area_label.sum()
    acc = total_area_intersect / total_area_label
    ret_metrics = [all_acc, acc]
    for metric in metrics:
        if metric == 'mIoU':
            iou = total_area_intersect / total_area_union
            ret_metrics.append(iou)
        elif metric == 'mDice':
            dice = 2 * total_area_intersect / (
                total_area_pred_label + total_area_label)
            ret_metrics.append(dice)
    if nan_to_num is not None:
        ret_metrics = [
            np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics
        ]
    return ret_metrics
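For orientation, a minimal usage sketch of the new module on toy arrays (not part of the commit; the toy values are made up, and the return order follows the docstrings above). Per class, IoU = intersect / union and Dice = 2 * intersect / (pred_area + label_area):

import numpy as np
from mmseg.core.evaluation import eval_metrics, mean_dice, mean_iou

num_classes = 3
ignore_index = 255
pred = [np.array([[0, 1], [2, 2]])]
gt = [np.array([[0, 1], [2, 255]])]  # the 255 pixel is masked out

all_acc, acc, iou = mean_iou(pred, gt, num_classes, ignore_index)
all_acc, acc, dice = mean_dice(pred, gt, num_classes, ignore_index)
# Requesting both metrics at once returns them in the requested order,
# after the overall and per-class accuracy:
all_acc, acc, iou, dice = eval_metrics(
    pred, gt, num_classes, ignore_index, metrics=['mIoU', 'mDice'])
print(acc, iou, dice)  # per-class arrays of shape (num_classes, )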
mmseg/datasets/custom.py
@@ -4,9 +4,10 @@ from functools import reduce
 import mmcv
 import numpy as np
 from mmcv.utils import print_log
+from terminaltables import AsciiTable
 from torch.utils.data import Dataset

-from mmseg.core import mean_iou
+from mmseg.core import eval_metrics
 from mmseg.utils import get_root_logger
 from .builder import DATASETS
 from .pipelines import Compose
@@ -14,9 +15,8 @@ from .pipelines import Compose

 @DATASETS.register_module()
 class CustomDataset(Dataset):
-    """Custom dataset for semantic segmentation.
-
-    An example of file structure is as followed.
+    """Custom dataset for semantic segmentation. An example of file structure
+    is as followed.

     .. code-block:: none

@@ -315,7 +315,8 @@ class CustomDataset(Dataset):

         Args:
             results (list): Testing results of the dataset.
-            metric (str | list[str]): Metrics to be evaluated.
+            metric (str | list[str]): Metrics to be evaluated. 'mIoU' and
+                'mDice' are supported.
             logger (logging.Logger | None | str): Logger used for printing
                 related information during evaluation. Default: None.

@@ -323,13 +324,11 @@
             dict[str, float]: Default metrics.
         """

-        if not isinstance(metric, str):
-            assert len(metric) == 1
-            metric = metric[0]
-        allowed_metrics = ['mIoU']
-        if metric not in allowed_metrics:
+        if isinstance(metric, str):
+            metric = [metric]
+        allowed_metrics = ['mIoU', 'mDice']
+        if not set(metric).issubset(set(allowed_metrics)):
             raise KeyError('metric {} is not supported'.format(metric))
         eval_results = {}
         gt_seg_maps = self.get_gt_seg_maps()
         if self.CLASSES is None:
@@ -337,35 +336,42 @@
                 reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
         else:
             num_classes = len(self.CLASSES)
-        all_acc, acc, iou = mean_iou(
-            results, gt_seg_maps, num_classes, ignore_index=self.ignore_index)
-        summary_str = ''
-        summary_str += 'per class results:\n'
-
-        line_format = '{:<15} {:>10} {:>10}\n'
-        summary_str += line_format.format('Class', 'IoU', 'Acc')
+        ret_metrics = eval_metrics(
+            results,
+            gt_seg_maps,
+            num_classes,
+            ignore_index=self.ignore_index,
+            metrics=metric)
+        class_table_data = [['Class'] + [m[1:] for m in metric] + ['Acc']]
         if self.CLASSES is None:
             class_names = tuple(range(num_classes))
         else:
             class_names = self.CLASSES
+        ret_metrics_round = [
+            np.round(ret_metric * 100, 2) for ret_metric in ret_metrics
+        ]
         for i in range(num_classes):
-            iou_str = '{:.2f}'.format(iou[i] * 100)
-            acc_str = '{:.2f}'.format(acc[i] * 100)
-            summary_str += line_format.format(class_names[i], iou_str, acc_str)
-        summary_str += 'Summary:\n'
-        line_format = '{:<15} {:>10} {:>10} {:>10}\n'
-        summary_str += line_format.format('Scope', 'mIoU', 'mAcc', 'aAcc')
-
-        iou_str = '{:.2f}'.format(np.nanmean(iou) * 100)
-        acc_str = '{:.2f}'.format(np.nanmean(acc) * 100)
-        all_acc_str = '{:.2f}'.format(all_acc * 100)
-        summary_str += line_format.format('global', iou_str, acc_str,
-                                          all_acc_str)
-        print_log(summary_str, logger)
-
-        eval_results['mIoU'] = np.nanmean(iou)
-        eval_results['mAcc'] = np.nanmean(acc)
-        eval_results['aAcc'] = all_acc
+            class_table_data.append([class_names[i]] +
+                                    [m[i] for m in ret_metrics_round[2:]] +
+                                    [ret_metrics_round[1][i]])
+        summary_table_data = [['Scope'] +
+                              ['m' + head
+                               for head in class_table_data[0][1:]] + ['aAcc']]
+        ret_metrics_mean = [
+            np.round(np.nanmean(ret_metric) * 100, 2)
+            for ret_metric in ret_metrics
+        ]
+        summary_table_data.append(['global'] + ret_metrics_mean[2:] +
+                                  [ret_metrics_mean[1]] +
+                                  [ret_metrics_mean[0]])
+        print_log('per class results:', logger)
+        table = AsciiTable(class_table_data)
+        print_log('\n' + table.table, logger=logger)
+        print_log('Summary:', logger)
+        table = AsciiTable(summary_table_data)
+        print_log('\n' + table.table, logger=logger)
+
+        for i in range(1, len(summary_table_data[0])):
+            eval_results[summary_table_data[0]
+                         [i]] = summary_table_data[1][i] / 100.0
         return eval_results
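The evaluate() rewrite above replaces the hand-formatted summary string with terminaltables. A standalone sketch of the table layout it builds, with made-up class names and numbers (not part of the diff):

from terminaltables import AsciiTable

# Per-class table: one column per requested metric (name without the leading
# 'm'), plus per-class accuracy; the summary table holds the global means.
# All values below are illustrative only.
class_table_data = [['Class', 'IoU', 'Dice', 'Acc'],
                    ['road', 97.1, 98.5, 98.9],
                    ['car', 88.4, 93.8, 94.2]]
summary_table_data = [['Scope', 'mIoU', 'mDice', 'mAcc', 'aAcc'],
                      ['global', 92.75, 96.15, 96.55, 97.3]]
print(AsciiTable(class_table_data).table)
print(AsciiTable(summary_table_data).table)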
(requirements file)
@@ -1,2 +1,3 @@
 matplotlib
 numpy
+terminaltables
(isort configuration)
@@ -8,6 +8,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmseg
-known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,torch
+known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,terminaltables,torch
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
(dataset tests)
@@ -159,17 +159,45 @@ def test_custom_dataset():
     for gt_seg_map in gt_seg_maps:
         h, w = gt_seg_map.shape
         pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
-    eval_results = train_dataset.evaluate(pseudo_results)
+    eval_results = train_dataset.evaluate(pseudo_results, metric='mIoU')
     assert isinstance(eval_results, dict)
     assert 'mIoU' in eval_results
     assert 'mAcc' in eval_results
     assert 'aAcc' in eval_results

-    # evaluation with CLASSES
-    train_dataset.CLASSES = tuple(['a'] * 7)
-    eval_results = train_dataset.evaluate(pseudo_results)
+    eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
+    assert isinstance(eval_results, dict)
+    assert 'mDice' in eval_results
+    assert 'mAcc' in eval_results
+    assert 'aAcc' in eval_results
+
+    eval_results = train_dataset.evaluate(
+        pseudo_results, metric=['mDice', 'mIoU'])
     assert isinstance(eval_results, dict)
     assert 'mIoU' in eval_results
+    assert 'mDice' in eval_results
+    assert 'mAcc' in eval_results
+    assert 'aAcc' in eval_results
+
+    # evaluation with CLASSES
+    train_dataset.CLASSES = tuple(['a'] * 7)
+    eval_results = train_dataset.evaluate(pseudo_results, metric='mIoU')
+    assert isinstance(eval_results, dict)
+    assert 'mIoU' in eval_results
+    assert 'mAcc' in eval_results
+    assert 'aAcc' in eval_results
+
+    eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
+    assert isinstance(eval_results, dict)
+    assert 'mDice' in eval_results
+    assert 'mAcc' in eval_results
+    assert 'aAcc' in eval_results
+
+    eval_results = train_dataset.evaluate(
+        pseudo_results, metric=['mIoU', 'mDice'])
+    assert isinstance(eval_results, dict)
+    assert 'mIoU' in eval_results
+    assert 'mDice' in eval_results
     assert 'mAcc' in eval_results
     assert 'aAcc' in eval_results
(deleted test file for the old mean_iou module)
@@ -1,63 +0,0 @@
import numpy as np

from mmseg.core.evaluation import mean_iou


def get_confusion_matrix(pred_label, label, num_classes, ignore_index):
    """Intersection over Union
    Args:
        pred_label (np.ndarray): 2D predict map
        label (np.ndarray): label 2D label map
        num_classes (int): number of categories
        ignore_index (int): index ignore in evaluation
    """

    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    n = num_classes
    inds = n * label + pred_label

    mat = np.bincount(inds, minlength=n**2).reshape(n, n)

    return mat


# This func is deprecated since it's not memory efficient
def legacy_mean_iou(results, gt_seg_maps, num_classes, ignore_index):
    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    total_mat = np.zeros((num_classes, num_classes), dtype=np.float)
    for i in range(num_imgs):
        mat = get_confusion_matrix(
            results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index)
        total_mat += mat
    all_acc = np.diag(total_mat).sum() / total_mat.sum()
    acc = np.diag(total_mat) / total_mat.sum(axis=1)
    iou = np.diag(total_mat) / (
        total_mat.sum(axis=1) + total_mat.sum(axis=0) - np.diag(total_mat))

    return all_acc, acc, iou


def test_mean_iou():
    pred_size = (10, 30, 30)
    num_classes = 19
    ignore_index = 255
    results = np.random.randint(0, num_classes, size=pred_size)
    label = np.random.randint(0, num_classes, size=pred_size)
    label[:, 2, 5:10] = ignore_index
    all_acc, acc, iou = mean_iou(results, label, num_classes, ignore_index)
    all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
                                              ignore_index)
    assert all_acc == all_acc_l
    assert np.allclose(acc, acc_l)
    assert np.allclose(iou, iou_l)

    results = np.random.randint(0, 5, size=pred_size)
    label = np.random.randint(0, 4, size=pred_size)
    all_acc, acc, iou = mean_iou(
        results, label, num_classes, ignore_index=255, nan_to_num=-1)
    assert acc[-1] == -1
    assert iou[-1] == -1
tests/test_metrics.py (new file, 166 lines)
@@ -0,0 +1,166 @@
import numpy as np

from mmseg.core.evaluation import eval_metrics, mean_dice, mean_iou


def get_confusion_matrix(pred_label, label, num_classes, ignore_index):
    """Intersection over Union
    Args:
        pred_label (np.ndarray): 2D predict map
        label (np.ndarray): label 2D label map
        num_classes (int): number of categories
        ignore_index (int): index ignore in evaluation
    """

    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    n = num_classes
    inds = n * label + pred_label

    mat = np.bincount(inds, minlength=n**2).reshape(n, n)

    return mat


# This func is deprecated since it's not memory efficient
def legacy_mean_iou(results, gt_seg_maps, num_classes, ignore_index):
    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    total_mat = np.zeros((num_classes, num_classes), dtype=np.float)
    for i in range(num_imgs):
        mat = get_confusion_matrix(
            results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index)
        total_mat += mat
    all_acc = np.diag(total_mat).sum() / total_mat.sum()
    acc = np.diag(total_mat) / total_mat.sum(axis=1)
    iou = np.diag(total_mat) / (
        total_mat.sum(axis=1) + total_mat.sum(axis=0) - np.diag(total_mat))

    return all_acc, acc, iou


# This func is deprecated since it's not memory efficient
def legacy_mean_dice(results, gt_seg_maps, num_classes, ignore_index):
    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    total_mat = np.zeros((num_classes, num_classes), dtype=np.float)
    for i in range(num_imgs):
        mat = get_confusion_matrix(
            results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index)
        total_mat += mat
    all_acc = np.diag(total_mat).sum() / total_mat.sum()
    acc = np.diag(total_mat) / total_mat.sum(axis=1)
    dice = 2 * np.diag(total_mat) / (
        total_mat.sum(axis=1) + total_mat.sum(axis=0))

    return all_acc, acc, dice


def test_metrics():
    pred_size = (10, 30, 30)
    num_classes = 19
    ignore_index = 255
    results = np.random.randint(0, num_classes, size=pred_size)
    label = np.random.randint(0, num_classes, size=pred_size)
    label[:, 2, 5:10] = ignore_index
    all_acc, acc, iou = eval_metrics(
        results, label, num_classes, ignore_index, metrics='mIoU')
    all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
                                              ignore_index)
    assert all_acc == all_acc_l
    assert np.allclose(acc, acc_l)
    assert np.allclose(iou, iou_l)

    all_acc, acc, dice = eval_metrics(
        results, label, num_classes, ignore_index, metrics='mDice')
    all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
                                                ignore_index)
    assert all_acc == all_acc_l
    assert np.allclose(acc, acc_l)
    assert np.allclose(dice, dice_l)

    all_acc, acc, iou, dice = eval_metrics(
        results, label, num_classes, ignore_index, metrics=['mIoU', 'mDice'])
    assert all_acc == all_acc_l
    assert np.allclose(acc, acc_l)
    assert np.allclose(iou, iou_l)
    assert np.allclose(dice, dice_l)

    results = np.random.randint(0, 5, size=pred_size)
    label = np.random.randint(0, 4, size=pred_size)
    all_acc, acc, iou = eval_metrics(
        results,
        label,
        num_classes,
        ignore_index=255,
        metrics='mIoU',
        nan_to_num=-1)
    assert acc[-1] == -1
    assert iou[-1] == -1

    all_acc, acc, dice = eval_metrics(
        results,
        label,
        num_classes,
        ignore_index=255,
        metrics='mDice',
        nan_to_num=-1)
    assert acc[-1] == -1
    assert dice[-1] == -1

    all_acc, acc, dice, iou = eval_metrics(
        results,
        label,
        num_classes,
        ignore_index=255,
        metrics=['mDice', 'mIoU'],
        nan_to_num=-1)
    assert acc[-1] == -1
    assert dice[-1] == -1
    assert iou[-1] == -1


def test_mean_iou():
    pred_size = (10, 30, 30)
    num_classes = 19
    ignore_index = 255
    results = np.random.randint(0, num_classes, size=pred_size)
    label = np.random.randint(0, num_classes, size=pred_size)
    label[:, 2, 5:10] = ignore_index
    all_acc, acc, iou = mean_iou(results, label, num_classes, ignore_index)
    all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
                                              ignore_index)
    assert all_acc == all_acc_l
    assert np.allclose(acc, acc_l)
    assert np.allclose(iou, iou_l)

    results = np.random.randint(0, 5, size=pred_size)
    label = np.random.randint(0, 4, size=pred_size)
    all_acc, acc, iou = mean_iou(
        results, label, num_classes, ignore_index=255, nan_to_num=-1)
    assert acc[-1] == -1
    assert iou[-1] == -1


def test_mean_dice():
    pred_size = (10, 30, 30)
    num_classes = 19
    ignore_index = 255
    results = np.random.randint(0, num_classes, size=pred_size)
    label = np.random.randint(0, num_classes, size=pred_size)
    label[:, 2, 5:10] = ignore_index
    all_acc, acc, iou = mean_dice(results, label, num_classes, ignore_index)
    all_acc_l, acc_l, iou_l = legacy_mean_dice(results, label, num_classes,
                                               ignore_index)
    assert all_acc == all_acc_l
    assert np.allclose(acc, acc_l)
    assert np.allclose(iou, iou_l)

    results = np.random.randint(0, 5, size=pred_size)
    label = np.random.randint(0, 4, size=pred_size)
    all_acc, acc, iou = mean_dice(
        results, label, num_classes, ignore_index=255, nan_to_num=-1)
    assert acc[-1] == -1
    assert iou[-1] == -1
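Why the legacy confusion-matrix helpers above agree with the new histogram-based implementation: for a confusion matrix C where C[i, j] counts pixels with ground-truth class i predicted as class j, the diagonal, row sums, and column sums are exactly the area terms that intersect_and_union() accumulates. A small sketch assuming only NumPy and random toy labels:

import numpy as np

rng = np.random.default_rng(0)
num_classes = 4
pred = rng.integers(0, num_classes, size=100)
label = rng.integers(0, num_classes, size=100)

# Confusion matrix: C[i, j] = #pixels with ground truth i predicted as j.
mat = np.bincount(num_classes * label + pred,
                  minlength=num_classes**2).reshape(num_classes, num_classes)

# Diagonal = per-class intersection, row sums = label areas, column sums =
# prediction areas, so IoU and Dice come out the same as from the histograms.
area_intersect = np.diag(mat)
area_label = mat.sum(axis=1)
area_pred = mat.sum(axis=0)
iou = area_intersect / (area_label + area_pred - area_intersect)
dice = 2 * area_intersect / (area_label + area_pred)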