[Fix] Update confusion_matrix.py (#3291)

## Motivation



## Modification

The confusion_matrix.py is not compatible with the current version of
mmseg.

---------

Co-authored-by: xiexinch <xiexinch@outlook.com>
This commit is contained in:
Gorgeous 2023-08-31 12:53:33 +08:00 committed by GitHub
parent 72e20a8854
commit ebd5695104
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -5,10 +5,14 @@ import os
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator
from mmengine import Config, DictAction
from mmengine.utils import ProgressBar, load
from mmengine.config import Config, DictAction
from mmengine.registry import init_default_scope
from mmengine.utils import mkdir_or_exist, progressbar
from PIL import Image
from mmseg.datasets import build_dataset
from mmseg.registry import DATASETS
init_default_scope('mmseg')
def parse_args():
@ -16,7 +20,7 @@ def parse_args():
description='Generate confusion matrix from segmentation results')
parser.add_argument('config', help='test config file path')
parser.add_argument(
'prediction_path', help='prediction path where test .pkl result')
'prediction_path', help='prediction path where test folder result')
parser.add_argument(
'save_dir', help='directory where confusion matrix will be saved')
parser.add_argument(
@ -50,15 +54,23 @@ def calculate_confusion_matrix(dataset, results):
dataset (Dataset): Test or val dataset.
results (list[ndarray]): A list of segmentation results in each image.
"""
n = len(dataset.CLASSES)
n = len(dataset.METAINFO['classes'])
confusion_matrix = np.zeros(shape=[n, n])
assert len(dataset) == len(results)
prog_bar = ProgressBar(len(results))
ignore_index = dataset.ignore_index
reduce_zero_label = dataset.reduce_zero_label
prog_bar = progressbar.ProgressBar(len(results))
for idx, per_img_res in enumerate(results):
res_segm = per_img_res
gt_segm = dataset.get_gt_seg_map_by_idx(idx)
gt_segm = dataset[idx]['data_samples'] \
.gt_sem_seg.data.squeeze().numpy().astype(np.uint8)
gt_segm, res_segm = gt_segm.flatten(), res_segm.flatten()
if reduce_zero_label:
gt_segm = gt_segm - 1
to_ignore = gt_segm == ignore_index
gt_segm, res_segm = gt_segm[~to_ignore], res_segm[~to_ignore]
inds = n * gt_segm + res_segm
inds = inds.flatten()
mat = np.bincount(inds, minlength=n**2).reshape(n, n)
confusion_matrix += mat
prog_bar.update()
@ -70,7 +82,7 @@ def plot_confusion_matrix(confusion_matrix,
save_dir=None,
show=True,
title='Normalized Confusion Matrix',
color_theme='winter'):
color_theme='OrRd'):
"""Draw confusion matrix with matplotlib.
Args:
@ -89,14 +101,15 @@ def plot_confusion_matrix(confusion_matrix,
num_classes = len(labels)
fig, ax = plt.subplots(
figsize=(2 * num_classes, 2 * num_classes * 0.8), dpi=180)
figsize=(2 * num_classes, 2 * num_classes * 0.8), dpi=300)
cmap = plt.get_cmap(color_theme)
im = ax.imshow(confusion_matrix, cmap=cmap)
plt.colorbar(mappable=im, ax=ax)
colorbar = plt.colorbar(mappable=im, ax=ax)
colorbar.ax.tick_params(labelsize=20) # 设置 colorbar 标签的字体大小
title_font = {'weight': 'bold', 'size': 12}
title_font = {'weight': 'bold', 'size': 20}
ax.set_title(title, fontdict=title_font)
label_font = {'size': 10}
label_font = {'size': 40}
plt.ylabel('Ground Truth Label', fontdict=label_font)
plt.xlabel('Prediction Label', fontdict=label_font)
@ -116,8 +129,8 @@ def plot_confusion_matrix(confusion_matrix,
# draw label
ax.set_xticks(np.arange(num_classes))
ax.set_yticks(np.arange(num_classes))
ax.set_xticklabels(labels)
ax.set_yticklabels(labels)
ax.set_xticklabels(labels, fontsize=20)
ax.set_yticklabels(labels, fontsize=20)
ax.tick_params(
axis='x', bottom=False, top=True, labelbottom=False, labeltop=True)
@ -135,13 +148,14 @@ def plot_confusion_matrix(confusion_matrix,
) if not np.isnan(confusion_matrix[i, j]) else -1),
ha='center',
va='center',
color='w',
size=7)
color='k',
size=20)
ax.set_ylim(len(confusion_matrix) - 0.5, -0.5) # matplotlib>3.1.1
fig.tight_layout()
if save_dir is not None:
mkdir_or_exist(save_dir)
plt.savefig(
os.path.join(save_dir, 'confusion_matrix.png'), format='png')
if show:
@ -155,7 +169,12 @@ def main():
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
results = load(args.prediction_path)
results = []
for img in sorted(os.listdir(args.prediction_path)):
img = os.path.join(args.prediction_path, img)
image = Image.open(img)
image = np.copy(image)
results.append(image)
assert isinstance(results, list)
if isinstance(results[0], np.ndarray):
@ -163,17 +182,11 @@ def main():
else:
raise TypeError('invalid type of prediction results')
if isinstance(cfg.data.test, dict):
cfg.data.test.test_mode = True
elif isinstance(cfg.data.test, list):
for ds_cfg in cfg.data.test:
ds_cfg.test_mode = True
dataset = build_dataset(cfg.data.test)
dataset = DATASETS.build(cfg.test_dataloader.dataset)
confusion_matrix = calculate_confusion_matrix(dataset, results)
plot_confusion_matrix(
confusion_matrix,
dataset.CLASSES,
dataset.METAINFO['classes'],
save_dir=args.save_dir,
show=args.show,
title=args.title,