[Feature] Support calculating the confusion matrix and plotting it. (#1287)

* [Feature] Support calculating the confusion matrix and plotting it.
* Fix keepdim
* Update the confusion_matrix tool and the plot graph.
* Revert accidental modification.
* Update docstring
* Move the confusion matrixtool.
parent 841256b630
commit b4ee9d2848
@@ -22,6 +22,7 @@ Single Label Metric

   Accuracy
   SingleLabelMetric
+  ConfusionMatrix

Multi Label Metric
----------------------
@@ -2,11 +2,11 @@
from .multi_label import AveragePrecision, MultiLabelMetric
from .multi_task import MultiTasksMetric
from .retrieval import RetrievalRecall
-from .single_label import Accuracy, SingleLabelMetric
+from .single_label import Accuracy, ConfusionMatrix, SingleLabelMetric
from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric

__all__ = [
    'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision',
    'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
-    'RetrievalRecall'
+    'ConfusionMatrix', 'RetrievalRecall'
]
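With `ConfusionMatrix` exported above, the metric can be imported directly or built through the `METRICS` registry, the same way an evaluator builds it from a config. A minimal sketch, assuming an environment with this change installed (`num_classes=10` is only illustrative):

```python
from mmcls.evaluation.metrics import ConfusionMatrix
from mmcls.registry import METRICS
from mmcls.utils import register_all_modules

# Populate the mmcls registries so the metric can be built by its type name.
register_all_modules()

# Build the metric from a config dict; num_classes=10 is an arbitrary example.
metric = METRICS.build(dict(type='ConfusionMatrix', num_classes=10))
assert isinstance(metric, ConfusionMatrix)
```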
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import product
from typing import List, Optional, Sequence, Union

import mmengine
@@ -26,7 +27,7 @@ def _precision_recall_f1_support(pred_positive, gt_positive, average):
    f1_score, support."""
    average_options = ['micro', 'macro', None]
    assert average in average_options, 'Invalid `average` argument, ' \
-        f'please specicy from {average_options}.'
+        f'please specify from {average_options}.'

    # ignore -1 target such as difficult sample that is not wanted
    # in evaluation results.
@@ -398,7 +399,7 @@ class SingleLabelMetric(BaseMetric):
        for item in items:
            assert item in ['precision', 'recall', 'f1-score', 'support'], \
                f'The metric {item} is not supported by `SingleLabelMetric`,' \
-                ' please specicy from "precision", "recall", "f1-score" and ' \
+                ' please specify from "precision", "recall", "f1-score" and ' \
                '"support".'
        self.items = tuple(items)
        self.average = average
@@ -549,7 +550,7 @@ class SingleLabelMetric(BaseMetric):
        """
        average_options = ['micro', 'macro', None]
        assert average in average_options, 'Invalid `average` argument, ' \
-            f'please specicy from {average_options}.'
+            f'please specify from {average_options}.'

        pred = to_tensor(pred)
        target = to_tensor(target).to(torch.int64)
@@ -559,7 +560,7 @@ class SingleLabelMetric(BaseMetric):

        if pred.ndim == 1:
            assert num_classes is not None, \
-                'Please specicy the `num_classes` if the `pred` is labels ' \
+                'Please specify the `num_classes` if the `pred` is labels ' \
                'instead of scores.'
            gt_positive = F.one_hot(target.flatten(), num_classes)
            pred_positive = F.one_hot(pred.to(torch.int64), num_classes)
@@ -584,3 +585,198 @@ class SingleLabelMetric(BaseMetric):
                average))

        return results


@METRICS.register_module()
class ConfusionMatrix(BaseMetric):
    r"""A metric to calculate the confusion matrix for single-label tasks.

    Args:
        num_classes (int, optional): The number of classes. Defaults to None.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.

    Examples:

        1. The basic usage.

        >>> import torch
        >>> from mmcls.evaluation import ConfusionMatrix
        >>> y_pred = [0, 1, 1, 3]
        >>> y_true = [0, 2, 1, 3]
        >>> ConfusionMatrix.calculate(y_pred, y_true, num_classes=4)
        tensor([[1, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 1]])
        >>> # plot the confusion matrix
        >>> import matplotlib.pyplot as plt
        >>> y_score = torch.rand((1000, 10))
        >>> y_true = torch.randint(10, (1000, ))
        >>> matrix = ConfusionMatrix.calculate(y_score, y_true)
        >>> ConfusionMatrix().plot(matrix)
        >>> plt.show()

        2. In the config file

        .. code:: python

            val_evaluator = dict(type='ConfusionMatrix')
            test_evaluator = dict(type='ConfusionMatrix')
    """  # noqa: E501
    default_prefix = 'confusion_matrix'

    def __init__(self,
                 num_classes: Optional[int] = None,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device, prefix)

        self.num_classes = num_classes

    def process(self, data_batch, data_samples: Sequence[dict]) -> None:
        for data_sample in data_samples:
            pred = data_sample['pred_label']
            gt_label = data_sample['gt_label']['label']
            if 'score' in pred:
                pred_label = pred['score'].argmax(dim=0, keepdim=True)
                self.num_classes = pred['score'].size(0)
            else:
                pred_label = pred['label']

            self.results.append({
                'pred_label': pred_label,
                'gt_label': gt_label
            })

    def compute_metrics(self, results: list) -> dict:
        pred_labels = []
        gt_labels = []
        for result in results:
            pred_labels.append(result['pred_label'])
            gt_labels.append(result['gt_label'])
        confusion_matrix = ConfusionMatrix.calculate(
            torch.cat(pred_labels),
            torch.cat(gt_labels),
            num_classes=self.num_classes)
        return {'result': confusion_matrix}

    @staticmethod
    def calculate(pred, target, num_classes=None) -> torch.Tensor:
        """Calculate the confusion matrix for single-label task.

        Args:
            pred (torch.Tensor | np.ndarray | Sequence): The prediction
                results. It can be labels (N, ), or scores of every
                class (N, C).
            target (torch.Tensor | np.ndarray | Sequence): The target of
                each prediction with shape (N, ).
            num_classes (int, optional): The number of classes. If ``pred``
                is labels instead of scores, this argument is required.
                Defaults to None.

        Returns:
            torch.Tensor: The confusion matrix.
        """
        pred = to_tensor(pred)
        target_label = to_tensor(target).int()

        assert pred.size(0) == target_label.size(0), \
            f"The size of pred ({pred.size(0)}) doesn't match "\
            f'the target ({target_label.size(0)}).'
        assert target_label.ndim == 1

        if pred.ndim == 1:
            assert num_classes is not None, \
                'Please specify the `num_classes` if the `pred` is labels ' \
                'instead of scores.'
            pred_label = pred
        else:
            num_classes = num_classes or pred.size(1)
            pred_label = torch.argmax(pred, dim=1).flatten()

        with torch.no_grad():
            indices = num_classes * target_label + pred_label
            matrix = torch.bincount(indices, minlength=num_classes**2)
            matrix = matrix.reshape(num_classes, num_classes)

        return matrix

    @staticmethod
    def plot(confusion_matrix: torch.Tensor,
             include_values: bool = False,
             cmap: str = 'viridis',
             classes: Optional[List[str]] = None,
             colorbar: bool = True,
             show: bool = True):
        """Draw a confusion matrix by matplotlib.

        Modified from `Scikit-Learn
        <https://github.com/scikit-learn/scikit-learn/blob/dc580a8ef/sklearn/metrics/_plot/confusion_matrix.py#L81>`_

        Args:
            confusion_matrix (torch.Tensor): The confusion matrix to draw.
            include_values (bool): Whether to draw the values in the figure.
                Defaults to False.
            cmap (str): The color map to use. Defaults to "viridis".
            classes (list[str], optional): The names of categories.
                Defaults to None, which means to use index numbers.
            colorbar (bool): Whether to show the colorbar. Defaults to True.
            show (bool): Whether to show the figure immediately.
                Defaults to True.
        """  # noqa: E501
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(figsize=(10, 10))

        num_classes = confusion_matrix.size(0)

        im_ = ax.imshow(confusion_matrix, interpolation='nearest', cmap=cmap)
        text_ = None
        cmap_min, cmap_max = im_.cmap(0), im_.cmap(1.0)

        if include_values:
            text_ = np.empty_like(confusion_matrix, dtype=object)

            # print text with appropriate color depending on background
            thresh = (confusion_matrix.max() + confusion_matrix.min()) / 2.0

            for i, j in product(range(num_classes), range(num_classes)):
                color = cmap_max if confusion_matrix[i,
                                                     j] < thresh else cmap_min

                text_cm = format(confusion_matrix[i, j], '.2g')
                text_d = format(confusion_matrix[i, j], 'd')
                if len(text_d) < len(text_cm):
                    text_cm = text_d

                text_[i, j] = ax.text(
                    j, i, text_cm, ha='center', va='center', color=color)

        display_labels = classes or np.arange(num_classes)

        if colorbar:
            fig.colorbar(im_, ax=ax)
        ax.set(
            xticks=np.arange(num_classes),
            yticks=np.arange(num_classes),
            xticklabels=display_labels,
            yticklabels=display_labels,
            ylabel='True label',
            xlabel='Predicted label',
        )
        ax.invert_yaxis()
        ax.xaxis.tick_top()

        ax.set_ylim((num_classes - 0.5, -0.5))
        # Automatically rotate the x labels.
        fig.autofmt_xdate(ha='center')

        if show:
            plt.show()
        return fig
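`ConfusionMatrix.calculate` counts (target, prediction) pairs by flattening each pair into the single index `num_classes * target + pred` and running `torch.bincount` over it. A standalone sketch of the same trick, using the same toy labels as the unit test further below, shows why the reshaped counts form the confusion matrix (rows are true labels, columns are predictions):

```python
import torch

num_classes = 3
target = torch.tensor([0, 0, 1, 2, 1, 0])  # true labels -> rows
pred = torch.tensor([0, 0, 1, 2, 2, 2])    # predicted labels -> columns

# Each (target, pred) pair maps to exactly one cell of a flattened C*C grid.
indices = num_classes * target + pred
matrix = torch.bincount(indices, minlength=num_classes**2)
matrix = matrix.reshape(num_classes, num_classes)
print(matrix)
# tensor([[2, 0, 1],
#         [0, 1, 1],
#         [0, 0, 1]])
```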
@@ -5,7 +5,8 @@ from unittest import TestCase
import numpy as np
import torch

-from mmcls.evaluation.metrics import Accuracy, SingleLabelMetric
+from mmcls.evaluation.metrics import (Accuracy, ConfusionMatrix,
+                                      SingleLabelMetric)
from mmcls.registry import METRICS
from mmcls.structures import ClsDataSample
@@ -296,3 +297,113 @@ class TestSingleLabel(TestCase):
            torch.testing.assert_allclose(tensor, value, **kwarg)
        except AssertionError as e:
            self.fail(self._formatMessage(msg, str(e)))


class TestConfusionMatrix(TestCase):

    def test_evaluate(self):
        """Test using the metric in the same way as Evaluator."""
        pred = [
            ClsDataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
                k).to_dict() for i, j, k in zip([
                    torch.tensor([0.7, 0.0, 0.3]),
                    torch.tensor([0.5, 0.2, 0.3]),
                    torch.tensor([0.4, 0.5, 0.1]),
                    torch.tensor([0.0, 0.0, 1.0]),
                    torch.tensor([0.0, 0.0, 1.0]),
                    torch.tensor([0.0, 0.0, 1.0]),
                ], [0, 0, 1, 2, 2, 2], [0, 0, 1, 2, 1, 0])
        ]

        # Test with score (use score instead of label if score exists)
        metric = METRICS.build(dict(type='ConfusionMatrix'))
        metric.process(None, pred)
        res = metric.evaluate(6)
        self.assertIsInstance(res, dict)
        self.assertTensorEqual(
            res['confusion_matrix/result'],
            torch.tensor([
                [2, 0, 1],
                [0, 1, 1],
                [0, 0, 1],
            ]))

        # Test with label
        for sample in pred:
            del sample['pred_label']['score']
        metric = METRICS.build(dict(type='ConfusionMatrix'))
        metric.process(None, pred)
        with self.assertRaisesRegex(AssertionError,
                                    'Please specify the `num_classes`'):
            metric.evaluate(6)

        metric = METRICS.build(dict(type='ConfusionMatrix', num_classes=3))
        metric.process(None, pred)
        res = metric.evaluate(6)
        self.assertIsInstance(res, dict)
        self.assertTensorEqual(
            res['confusion_matrix/result'],
            torch.tensor([
                [2, 0, 1],
                [0, 1, 1],
                [0, 0, 1],
            ]))

    def test_calculate(self):
        y_true = np.array([0, 0, 1, 2, 1, 0])
        y_label = torch.tensor([0, 0, 1, 2, 2, 2])
        y_score = [
            [0.7, 0.0, 0.3],
            [0.5, 0.2, 0.3],
            [0.4, 0.5, 0.1],
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0],
        ]

        # Test with score
        cm = ConfusionMatrix.calculate(y_score, y_true)
        self.assertIsInstance(cm, torch.Tensor)
        self.assertTensorEqual(
            cm, torch.tensor([
                [2, 0, 1],
                [0, 1, 1],
                [0, 0, 1],
            ]))

        # Test with label
        with self.assertRaisesRegex(AssertionError,
                                    'Please specify the `num_classes`'):
            ConfusionMatrix.calculate(y_label, y_true)

        cm = ConfusionMatrix.calculate(y_label, y_true, num_classes=3)
        self.assertIsInstance(cm, torch.Tensor)
        self.assertTensorEqual(
            cm, torch.tensor([
                [2, 0, 1],
                [0, 1, 1],
                [0, 0, 1],
            ]))

        # Test with invalid inputs
        with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
            ConfusionMatrix.calculate(y_label, 'hi')

    def test_plot(self):
        import matplotlib.pyplot as plt

        cm = torch.tensor([[2, 0, 1], [0, 1, 1], [0, 0, 1]])
        fig = ConfusionMatrix.plot(cm, include_values=True, show=False)

        self.assertIsInstance(fig, plt.Figure)

    def assertTensorEqual(self,
                          tensor: torch.Tensor,
                          value: float,
                          msg=None,
                          **kwarg):
        tensor = tensor.to(torch.float32)
        value = torch.tensor(value).float()
        try:
            torch.testing.assert_allclose(tensor, value, **kwarg)
        except AssertionError as e:
            self.fail(self._formatMessage(msg, str(e)))
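As a complement to `test_plot` above, a hedged usage sketch of the plotting helper with explicit category names (the class names and the output path are hypothetical):

```python
import torch
from mmcls.evaluation import ConfusionMatrix

cm = torch.tensor([[2, 0, 1], [0, 1, 1], [0, 0, 1]])
fig = ConfusionMatrix.plot(
    cm,
    include_values=True,             # draw the count inside every cell
    classes=['cat', 'dog', 'bird'],  # hypothetical category names
    show=False)
fig.savefig('confusion_matrix.png')  # hypothetical output path
```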
@@ -0,0 +1,108 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import tempfile

import mmengine
from mmengine.config import Config, DictAction
from mmengine.evaluator import Evaluator
from mmengine.runner import Runner

from mmcls.evaluation import ConfusionMatrix
from mmcls.registry import DATASETS
from mmcls.utils import register_all_modules


def parse_args():
    parser = argparse.ArgumentParser(
        description='Eval a checkpoint and draw the confusion matrix.')
    parser.add_argument('config', help='test config file path')
    parser.add_argument(
        'ckpt_or_result',
        type=str,
        help='The checkpoint file (.pth) or '
        'dumped predictions pickle file (.pkl).')
    parser.add_argument('--out', help='the file to save the confusion matrix.')
    parser.add_argument(
        '--show',
        action='store_true',
        help='whether to display the metric result with matplotlib if supported.')
    parser.add_argument(
        '--show-path', type=str, help='Path to save the visualization image.')
    parser.add_argument(
        '--include-values',
        action='store_true',
        help='To draw the values in the figure.')
    parser.add_argument(
        '--cmap',
        type=str,
        default='viridis',
        help='The color map to use. Defaults to "viridis".')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    # register all modules in mmcls into the registries
    # do not init the default scope here because it will be initialized in the runner
    register_all_modules(init_default_scope=False)

    # load config
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    if args.ckpt_or_result.endswith('.pth'):
        # Set confusion matrix as the metric.
        cfg.test_evaluator = dict(type='ConfusionMatrix')

        cfg.load_from = str(args.ckpt_or_result)

        with tempfile.TemporaryDirectory() as tmpdir:
            cfg.work_dir = tmpdir
            runner = Runner.from_cfg(cfg)
            classes = runner.test_loop.dataloader.dataset.metainfo.get(
                'classes')
            cm = runner.test()['confusion_matrix/result']
    else:
        predictions = mmengine.load(args.ckpt_or_result)
        evaluator = Evaluator(ConfusionMatrix())
        metrics = evaluator.offline_evaluate(predictions, None)
        cm = metrics['confusion_matrix/result']
        try:
            # Try to build the dataset.
            dataset = DATASETS.build({
                **cfg.test_dataloader.dataset, 'pipeline': []
            })
            classes = dataset.metainfo.get('classes')
        except Exception:
            classes = None

    if args.out is not None:
        mmengine.dump(cm, args.out)

    if args.show or args.show_path is not None:
        fig = ConfusionMatrix.plot(
            cm,
            show=args.show,
            classes=classes,
            include_values=args.include_values,
            cmap=args.cmap)
        if args.show_path is not None:
            fig.savefig(args.show_path)
            print(f'The confusion matrix is saved at {args.show_path}.')


if __name__ == '__main__':
    main()
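The offline branch of this script can also be reproduced interactively. A minimal sketch, assuming a pickle file of dumped prediction dicts from an earlier test run (the file name is hypothetical):

```python
import mmengine
from mmengine.evaluator import Evaluator

from mmcls.evaluation import ConfusionMatrix

# Load predictions dumped by an earlier test run (hypothetical file name).
predictions = mmengine.load('result.pkl')

evaluator = Evaluator(ConfusionMatrix())
metrics = evaluator.offline_evaluate(predictions, None)
cm = metrics['confusion_matrix/result']

# Draw the matrix in a window, with per-cell counts.
ConfusionMatrix.plot(cm, include_values=True, show=True)
```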