[Feature] Support calculating the confusion matrix and plotting it. (#1287)

* [Feature] Support calculating the confusion matrix and plotting it.

* Fix keepdim

* Update the confusion_matrix tool and the plot.

* Revert accidental modification.

* Update docstring

* Move confusion matrix tool to
Ma Zerun 2023-02-14 12:58:11 +08:00 committed by GitHub
parent 841256b630
commit b4ee9d2848
5 changed files with 423 additions and 7 deletions
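In short, the new metric computes a confusion matrix from predictions and can draw it with matplotlib. A minimal usage sketch, mirroring the docstring example added in the diff below (the random scores are illustrative only):

import torch
import matplotlib.pyplot as plt
from mmcls.evaluation import ConfusionMatrix

y_score = torch.rand((1000, 10))       # predicted scores, shape (N, C)
y_true = torch.randint(10, (1000, ))   # ground-truth labels, shape (N, )
matrix = ConfusionMatrix.calculate(y_score, y_true)
ConfusionMatrix.plot(matrix)           # draws the matrix and returns a Figure
plt.show()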


@@ -22,6 +22,7 @@ Single Label Metric
Accuracy
SingleLabelMetric
ConfusionMatrix
Multi Label Metric
----------------------


@@ -2,11 +2,11 @@
from .multi_label import AveragePrecision, MultiLabelMetric
from .multi_task import MultiTasksMetric
from .retrieval import RetrievalRecall
-from .single_label import Accuracy, SingleLabelMetric
+from .single_label import Accuracy, ConfusionMatrix, SingleLabelMetric
from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric
__all__ = [
'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision',
'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
-'RetrievalRecall'
+'ConfusionMatrix', 'RetrievalRecall'
]


@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from itertools import product
from typing import List, Optional, Sequence, Union
import mmengine
@@ -26,7 +27,7 @@ def _precision_recall_f1_support(pred_positive, gt_positive, average):
f1_score, support."""
average_options = ['micro', 'macro', None]
assert average in average_options, 'Invalid `average` argument, ' \
-f'please specicy from {average_options}.'
+f'please specify from {average_options}.'
# ignore -1 target such as difficult sample that is not wanted
# in evaluation results.
@@ -398,7 +399,7 @@ class SingleLabelMetric(BaseMetric):
for item in items:
assert item in ['precision', 'recall', 'f1-score', 'support'], \
f'The metric {item} is not supported by `SingleLabelMetric`,' \
-' please specicy from "precision", "recall", "f1-score" and ' \
+' please specify from "precision", "recall", "f1-score" and ' \
'"support".'
self.items = tuple(items)
self.average = average
@@ -549,7 +550,7 @@ class SingleLabelMetric(BaseMetric):
"""
average_options = ['micro', 'macro', None]
assert average in average_options, 'Invalid `average` argument, ' \
-f'please specicy from {average_options}.'
+f'please specify from {average_options}.'
pred = to_tensor(pred)
target = to_tensor(target).to(torch.int64)
@@ -559,7 +560,7 @@ class SingleLabelMetric(BaseMetric):
if pred.ndim == 1:
assert num_classes is not None, \
-'Please specicy the `num_classes` if the `pred` is labels ' \
+'Please specify the `num_classes` if the `pred` is labels ' \
'instead of scores.'
gt_positive = F.one_hot(target.flatten(), num_classes)
pred_positive = F.one_hot(pred.to(torch.int64), num_classes)
@@ -584,3 +585,198 @@ class SingleLabelMetric(BaseMetric):
average))
return results
@METRICS.register_module()
class ConfusionMatrix(BaseMetric):
r"""A metric to calculate confusion matrix for single-label tasks.
Args:
num_classes (int, optional): The number of classes. Defaults to None.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Defaults to None.
Examples:
1. The basic usage.
>>> import torch
>>> from mmcls.evaluation import ConfusionMatrix
>>> y_pred = [0, 1, 1, 3]
>>> y_true = [0, 2, 1, 3]
>>> ConfusionMatrix.calculate(y_pred, y_true, num_classes=4)
tensor([[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 1]])
>>> # plot the confusion matrix
>>> import matplotlib.pyplot as plt
>>> y_score = torch.rand((1000, 10))
>>> y_true = torch.randint(10, (1000, ))
>>> matrix = ConfusionMatrix.calculate(y_score, y_true)
>>> ConfusionMatrix().plot(matrix)
>>> plt.show()
2. In the config file
.. code:: python
val_evaluator = dict(type='ConfusionMatrix')
test_evaluator = dict(type='ConfusionMatrix')
""" # noqa: E501
default_prefix = 'confusion_matrix'
def __init__(self,
num_classes: Optional[int] = None,
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device, prefix)
self.num_classes = num_classes
def process(self, data_batch, data_samples: Sequence[dict]) -> None:
for data_sample in data_samples:
pred = data_sample['pred_label']
gt_label = data_sample['gt_label']['label']
if 'score' in pred:
pred_label = pred['score'].argmax(dim=0, keepdim=True)
self.num_classes = pred['score'].size(0)
else:
pred_label = pred['label']
self.results.append({
'pred_label': pred_label,
'gt_label': gt_label
})
def compute_metrics(self, results: list) -> dict:
pred_labels = []
gt_labels = []
for result in results:
pred_labels.append(result['pred_label'])
gt_labels.append(result['gt_label'])
confusion_matrix = ConfusionMatrix.calculate(
torch.cat(pred_labels),
torch.cat(gt_labels),
num_classes=self.num_classes)
return {'result': confusion_matrix}
@staticmethod
def calculate(pred, target, num_classes=None) -> torch.Tensor:
"""Calculate the confusion matrix for single-label task.
Args:
pred (torch.Tensor | np.ndarray | Sequence): The prediction
results. It can be labels (N, ), or scores of every
class (N, C).
target (torch.Tensor | np.ndarray | Sequence): The target of
each prediction with shape (N, ).
num_classes (int, optional): The number of classes. If the ``pred``
is labels instead of scores, this argument is required.
Defaults to None.
Returns:
torch.Tensor: The confusion matrix.
"""
pred = to_tensor(pred)
target_label = to_tensor(target).int()
assert pred.size(0) == target_label.size(0), \
f"The size of pred ({pred.size(0)}) doesn't match "\
f'the target ({target_label.size(0)}).'
assert target_label.ndim == 1
if pred.ndim == 1:
assert num_classes is not None, \
'Please specify the `num_classes` if the `pred` is labels ' \
'instead of scores.'
pred_label = pred
else:
num_classes = num_classes or pred.size(1)
pred_label = torch.argmax(pred, dim=1).flatten()
with torch.no_grad():
indices = num_classes * target_label + pred_label
matrix = torch.bincount(indices, minlength=num_classes**2)
matrix = matrix.reshape(num_classes, num_classes)
return matrix
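# A small worked example of the indexing trick above: with num_classes = 3,
# target_label = tensor([0, 2]) and pred_label = tensor([0, 1]), the flat
# indices are 3 * [0, 2] + [0, 1] = [0, 7]; bincount(minlength=9) gives
# [1, 0, 0, 0, 0, 0, 0, 1, 0], which reshapes to a 3x3 matrix with
# matrix[0, 0] = 1 and matrix[2, 1] = 1 (rows: ground truth, columns: prediction).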
@staticmethod
def plot(confusion_matrix: torch.Tensor,
include_values: bool = False,
cmap: str = 'viridis',
classes: Optional[List[str]] = None,
colorbar: bool = True,
show: bool = True):
"""Draw a confusion matrix by matplotlib.
Modified from `Scikit-Learn
<https://github.com/scikit-learn/scikit-learn/blob/dc580a8ef/sklearn/metrics/_plot/confusion_matrix.py#L81>`_
Args:
confusion_matrix (torch.Tensor): The confusion matrix to draw.
include_values (bool): Whether to draw the values in the figure.
Defaults to False.
cmap (str): The color map to use. Defaults to use "viridis".
classes (list[str], optional): The names of categories.
Defaults to None, which means to use index number.
colorbar (bool): Whether to show the colorbar. Defaults to True.
show (bool): Whether to show the figure immediately.
Defaults to True.
""" # noqa: E501
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(10, 10))
num_classes = confusion_matrix.size(0)
im_ = ax.imshow(confusion_matrix, interpolation='nearest', cmap=cmap)
text_ = None
cmap_min, cmap_max = im_.cmap(0), im_.cmap(1.0)
if include_values:
text_ = np.empty_like(confusion_matrix, dtype=object)
# print text with appropriate color depending on background
thresh = (confusion_matrix.max() + confusion_matrix.min()) / 2.0
for i, j in product(range(num_classes), range(num_classes)):
color = cmap_max if confusion_matrix[i, j] < thresh else cmap_min
text_cm = format(confusion_matrix[i, j], '.2g')
text_d = format(confusion_matrix[i, j], 'd')
if len(text_d) < len(text_cm):
text_cm = text_d
text_[i, j] = ax.text(
j, i, text_cm, ha='center', va='center', color=color)
display_labels = classes or np.arange(num_classes)
if colorbar:
fig.colorbar(im_, ax=ax)
ax.set(
xticks=np.arange(num_classes),
yticks=np.arange(num_classes),
xticklabels=display_labels,
yticklabels=display_labels,
ylabel='True label',
xlabel='Predicted label',
)
ax.invert_yaxis()
ax.xaxis.tick_top()
ax.set_ylim((num_classes - 0.5, -0.5))
# Automatically rotate the x labels.
fig.autofmt_xdate(ha='center')
if show:
plt.show()
return fig
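# Because ``plot`` returns the figure, it can also be used headlessly,
# e.g. (the file name here is only a placeholder):
#   fig = ConfusionMatrix.plot(matrix, include_values=True, show=False)
#   fig.savefig('confusion_matrix.png')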


@@ -5,7 +5,8 @@ from unittest import TestCase
import numpy as np
import torch
-from mmcls.evaluation.metrics import Accuracy, SingleLabelMetric
+from mmcls.evaluation.metrics import (Accuracy, ConfusionMatrix,
+                                      SingleLabelMetric)
from mmcls.registry import METRICS
from mmcls.structures import ClsDataSample
@@ -296,3 +297,113 @@ class TestSingleLabel(TestCase):
torch.testing.assert_allclose(tensor, value, **kwarg)
except AssertionError as e:
self.fail(self._formatMessage(msg, str(e)))
class TestConfusionMatrix(TestCase):
def test_evaluate(self):
"""Test using the metric in the same way as Evalutor."""
pred = [
ClsDataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
k).to_dict() for i, j, k in zip([
torch.tensor([0.7, 0.0, 0.3]),
torch.tensor([0.5, 0.2, 0.3]),
torch.tensor([0.4, 0.5, 0.1]),
torch.tensor([0.0, 0.0, 1.0]),
torch.tensor([0.0, 0.0, 1.0]),
torch.tensor([0.0, 0.0, 1.0]),
], [0, 0, 1, 2, 2, 2], [0, 0, 1, 2, 1, 0])
]
# Test with score (use score instead of label if score exists)
metric = METRICS.build(dict(type='ConfusionMatrix'))
metric.process(None, pred)
res = metric.evaluate(6)
self.assertIsInstance(res, dict)
self.assertTensorEqual(
res['confusion_matrix/result'],
torch.tensor([
[2, 0, 1],
[0, 1, 1],
[0, 0, 1],
]))
# Test with label
for sample in pred:
del sample['pred_label']['score']
metric = METRICS.build(dict(type='ConfusionMatrix'))
metric.process(None, pred)
with self.assertRaisesRegex(AssertionError,
'Please specify the `num_classes`'):
metric.evaluate(6)
metric = METRICS.build(dict(type='ConfusionMatrix', num_classes=3))
metric.process(None, pred)
res = metric.evaluate(6)
self.assertIsInstance(res, dict)
self.assertTensorEqual(
res['confusion_matrix/result'],
torch.tensor([
[2, 0, 1],
[0, 1, 1],
[0, 0, 1],
]))
def test_calculate(self):
y_true = np.array([0, 0, 1, 2, 1, 0])
y_label = torch.tensor([0, 0, 1, 2, 2, 2])
y_score = [
[0.7, 0.0, 0.3],
[0.5, 0.2, 0.3],
[0.4, 0.5, 0.1],
[0.0, 0.0, 1.0],
[0.0, 0.0, 1.0],
[0.0, 0.0, 1.0],
]
# Test with score
cm = ConfusionMatrix.calculate(y_score, y_true)
self.assertIsInstance(cm, torch.Tensor)
self.assertTensorEqual(
cm, torch.tensor([
[2, 0, 1],
[0, 1, 1],
[0, 0, 1],
]))
# Test with label
with self.assertRaisesRegex(AssertionError,
'Please specify the `num_classes`'):
ConfusionMatrix.calculate(y_label, y_true)
cm = ConfusionMatrix.calculate(y_label, y_true, num_classes=3)
self.assertIsInstance(cm, torch.Tensor)
self.assertTensorEqual(
cm, torch.tensor([
[2, 0, 1],
[0, 1, 1],
[0, 0, 1],
]))
# Test with invalid inputs
with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
ConfusionMatrix.calculate(y_label, 'hi')
def test_plot(self):
import matplotlib.pyplot as plt
cm = torch.tensor([[2, 0, 1], [0, 1, 1], [0, 0, 1]])
fig = ConfusionMatrix.plot(cm, include_values=True, show=False)
self.assertIsInstance(fig, plt.Figure)
def assertTensorEqual(self,
tensor: torch.Tensor,
value: float,
msg=None,
**kwarg):
tensor = tensor.to(torch.float32)
value = torch.tensor(value).float()
try:
torch.testing.assert_allclose(tensor, value, **kwarg)
except AssertionError as e:
self.fail(self._formatMessage(msg, str(e)))


@@ -0,0 +1,108 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import tempfile
import mmengine
from mmengine.config import Config, DictAction
from mmengine.evaluator import Evaluator
from mmengine.runner import Runner
from mmcls.evaluation import ConfusionMatrix
from mmcls.registry import DATASETS
from mmcls.utils import register_all_modules
def parse_args():
parser = argparse.ArgumentParser(
description='Eval a checkpoint and draw the confusion matrix.')
parser.add_argument('config', help='test config file path')
parser.add_argument(
'ckpt_or_result',
type=str,
help='The checkpoint file (.pth) or '
'dumped predictions pickle file (.pkl).')
parser.add_argument('--out', help='the file to save the confusion matrix.')
parser.add_argument(
'--show',
action='store_true',
help='whether to display the metric result with matplotlib if supported.')
parser.add_argument(
'--show-path', type=str, help='Path to save the visualization image.')
parser.add_argument(
'--include-values',
action='store_true',
help='To draw the values in the figure.')
parser.add_argument(
'--cmap',
type=str,
default='viridis',
help='The color map to use. Defaults to "viridis".')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
args = parser.parse_args()
return args
def main():
args = parse_args()
# register all modules in mmcls into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)
# load config
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
if args.ckpt_or_result.endswith('.pth'):
# Set confusion matrix as the metric.
cfg.test_evaluator = dict(type='ConfusionMatrix')
cfg.load_from = str(args.ckpt_or_result)
with tempfile.TemporaryDirectory() as tmpdir:
cfg.work_dir = tmpdir
runner = Runner.from_cfg(cfg)
classes = runner.test_loop.dataloader.dataset.metainfo.get(
'classes')
cm = runner.test()['confusion_matrix/result']
else:
predictions = mmengine.load(args.ckpt_or_result)
evaluator = Evaluator(ConfusionMatrix())
metrics = evaluator.offline_evaluate(predictions, None)
cm = metrics['confusion_matrix/result']
try:
# Try to build the dataset.
dataset = DATASETS.build({
**cfg.test_dataloader.dataset, 'pipeline': []
})
classes = dataset.metainfo.get('classes')
except Exception:
classes = None
if args.out is not None:
mmengine.dump(cm, args.out)
if args.show or args.show_path is not None:
fig = ConfusionMatrix.plot(
cm,
show=args.show,
classes=classes,
include_values=args.include_values,
cmap=args.cmap)
if args.show_path is not None:
fig.savefig(args.show_path)
print(f'The confusion matrix is saved at {args.show_path}.')
if __name__ == '__main__':
main()
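For reference, a minimal sketch of the offline path the tool above takes when it is given a dumped prediction file rather than a checkpoint ('results.pkl' and the output file name are placeholders, not names from this commit):

import mmengine
from mmengine.evaluator import Evaluator
from mmcls.evaluation import ConfusionMatrix

# 'results.pkl' stands in for a pickle of dumped test predictions.
predictions = mmengine.load('results.pkl')
evaluator = Evaluator(ConfusionMatrix())
metrics = evaluator.offline_evaluate(predictions, None)
fig = ConfusionMatrix.plot(metrics['confusion_matrix/result'], show=False)
fig.savefig('confusion_matrix.png')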