diff --git a/tools/analysis_tools/confusion_matrix.py b/tools/analysis_tools/confusion_matrix.py index 9a87bc14c..39756cdfd 100644 --- a/tools/analysis_tools/confusion_matrix.py +++ b/tools/analysis_tools/confusion_matrix.py @@ -5,10 +5,14 @@ import os import matplotlib.pyplot as plt import numpy as np from matplotlib.ticker import MultipleLocator -from mmengine import Config, DictAction -from mmengine.utils import ProgressBar, load +from mmengine.config import Config, DictAction +from mmengine.registry import init_default_scope +from mmengine.utils import mkdir_or_exist, progressbar +from PIL import Image -from mmseg.datasets import build_dataset +from mmseg.registry import DATASETS + +init_default_scope('mmseg') def parse_args(): @@ -16,7 +20,7 @@ def parse_args(): description='Generate confusion matrix from segmentation results') parser.add_argument('config', help='test config file path') parser.add_argument( - 'prediction_path', help='prediction path where test .pkl result') + 'prediction_path', help='prediction path where test folder result') parser.add_argument( 'save_dir', help='directory where confusion matrix will be saved') parser.add_argument( @@ -50,15 +54,23 @@ def calculate_confusion_matrix(dataset, results): dataset (Dataset): Test or val dataset. results (list[ndarray]): A list of segmentation results in each image. """ - n = len(dataset.CLASSES) + n = len(dataset.METAINFO['classes']) confusion_matrix = np.zeros(shape=[n, n]) assert len(dataset) == len(results) - prog_bar = ProgressBar(len(results)) + ignore_index = dataset.ignore_index + reduce_zero_label = dataset.reduce_zero_label + prog_bar = progressbar.ProgressBar(len(results)) for idx, per_img_res in enumerate(results): res_segm = per_img_res - gt_segm = dataset.get_gt_seg_map_by_idx(idx) + gt_segm = dataset[idx]['data_samples'] \ + .gt_sem_seg.data.squeeze().numpy().astype(np.uint8) + gt_segm, res_segm = gt_segm.flatten(), res_segm.flatten() + if reduce_zero_label: + gt_segm = gt_segm - 1 + to_ignore = gt_segm == ignore_index + + gt_segm, res_segm = gt_segm[~to_ignore], res_segm[~to_ignore] inds = n * gt_segm + res_segm - inds = inds.flatten() mat = np.bincount(inds, minlength=n**2).reshape(n, n) confusion_matrix += mat prog_bar.update() @@ -70,7 +82,7 @@ def plot_confusion_matrix(confusion_matrix, save_dir=None, show=True, title='Normalized Confusion Matrix', - color_theme='winter'): + color_theme='OrRd'): """Draw confusion matrix with matplotlib. Args: @@ -89,14 +101,15 @@ def plot_confusion_matrix(confusion_matrix, num_classes = len(labels) fig, ax = plt.subplots( - figsize=(2 * num_classes, 2 * num_classes * 0.8), dpi=180) + figsize=(2 * num_classes, 2 * num_classes * 0.8), dpi=300) cmap = plt.get_cmap(color_theme) im = ax.imshow(confusion_matrix, cmap=cmap) - plt.colorbar(mappable=im, ax=ax) + colorbar = plt.colorbar(mappable=im, ax=ax) + colorbar.ax.tick_params(labelsize=20) # 设置 colorbar 标签的字体大小 - title_font = {'weight': 'bold', 'size': 12} + title_font = {'weight': 'bold', 'size': 20} ax.set_title(title, fontdict=title_font) - label_font = {'size': 10} + label_font = {'size': 40} plt.ylabel('Ground Truth Label', fontdict=label_font) plt.xlabel('Prediction Label', fontdict=label_font) @@ -116,8 +129,8 @@ def plot_confusion_matrix(confusion_matrix, # draw label ax.set_xticks(np.arange(num_classes)) ax.set_yticks(np.arange(num_classes)) - ax.set_xticklabels(labels) - ax.set_yticklabels(labels) + ax.set_xticklabels(labels, fontsize=20) + ax.set_yticklabels(labels, fontsize=20) ax.tick_params( axis='x', bottom=False, top=True, labelbottom=False, labeltop=True) @@ -135,13 +148,14 @@ def plot_confusion_matrix(confusion_matrix, ) if not np.isnan(confusion_matrix[i, j]) else -1), ha='center', va='center', - color='w', - size=7) + color='k', + size=20) ax.set_ylim(len(confusion_matrix) - 0.5, -0.5) # matplotlib>3.1.1 fig.tight_layout() if save_dir is not None: + mkdir_or_exist(save_dir) plt.savefig( os.path.join(save_dir, 'confusion_matrix.png'), format='png') if show: @@ -155,7 +169,12 @@ def main(): if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) - results = load(args.prediction_path) + results = [] + for img in sorted(os.listdir(args.prediction_path)): + img = os.path.join(args.prediction_path, img) + image = Image.open(img) + image = np.copy(image) + results.append(image) assert isinstance(results, list) if isinstance(results[0], np.ndarray): @@ -163,17 +182,11 @@ def main(): else: raise TypeError('invalid type of prediction results') - if isinstance(cfg.data.test, dict): - cfg.data.test.test_mode = True - elif isinstance(cfg.data.test, list): - for ds_cfg in cfg.data.test: - ds_cfg.test_mode = True - - dataset = build_dataset(cfg.data.test) + dataset = DATASETS.build(cfg.test_dataloader.dataset) confusion_matrix = calculate_confusion_matrix(dataset, results) plot_confusion_matrix( confusion_matrix, - dataset.CLASSES, + dataset.METAINFO['classes'], save_dir=args.save_dir, show=args.show, title=args.title,