diff --git a/mmcls/apis/inference.py b/mmcls/apis/inference.py index 94bf2e70..5907328f 100644 --- a/mmcls/apis/inference.py +++ b/mmcls/apis/inference.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import warnings -import matplotlib.pyplot as plt import mmcv import numpy as np import torch @@ -90,7 +89,7 @@ def inference_model(model, img): return result -def show_result_pyplot(model, img, result, fig_size=(15, 10)): +def show_result_pyplot(model, img, result, fig_size=(15, 10), wait_time=0): """Visualize the classification results on the image. Args: @@ -98,10 +97,11 @@ def show_result_pyplot(model, img, result, fig_size=(15, 10)): img (str or np.ndarray): Image filename or loaded image. result (list): The classification result. fig_size (tuple): Figure size of the pyplot figure. + Defaults to (15, 10). + wait_time (int): How many seconds to display the image. + Defaults to 0. """ if hasattr(model, 'module'): model = model.module - img = model.show_result(img, result, show=False) - plt.figure(figsize=fig_size) - plt.imshow(mmcv.bgr2rgb(img)) - plt.show() + model.show_result( + img, result, show=True, fig_size=fig_size, wait_time=wait_time) diff --git a/mmcls/core/visualization/__init__.py b/mmcls/core/visualization/__init__.py new file mode 100644 index 00000000..ea89d749 --- /dev/null +++ b/mmcls/core/visualization/__init__.py @@ -0,0 +1,3 @@ +from .image import color_val_matplotlib, imshow_infos + +__all__ = ['imshow_infos', 'color_val_matplotlib'] diff --git a/mmcls/core/visualization/image.py b/mmcls/core/visualization/image.py new file mode 100644 index 00000000..9076eddb --- /dev/null +++ b/mmcls/core/visualization/image.py @@ -0,0 +1,130 @@ +import matplotlib.pyplot as plt +import mmcv +import numpy as np + +# A small value +EPS = 1e-2 + + +def color_val_matplotlib(color): + """Convert various input in BGR order to normalized RGB matplotlib color + tuples, + + Args: + color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Color inputs + + Returns: + tuple[float]: A tuple of 3 normalized floats indicating RGB channels. + """ + color = mmcv.color_val(color) + color = [color / 255 for color in color[::-1]] + return tuple(color) + + +def imshow_infos(img, + infos, + text_color='white', + font_size=26, + row_width=20, + win_name='', + show=True, + fig_size=(15, 10), + wait_time=0, + out_file=None): + """Show image with extra infomation. + + Args: + img (str | ndarray): The image to be displayed. + infos (dict): Extra infos to display in the image. + text_color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Extra infos + display color. Defaults to 'white'. + font_size (int): Extra infos display font size. Defaults to 26. + row_width (int): width between each row of results on the image. + win_name (str): The image title. Defaults to '' + show (bool): Whether to show the image. Defaults to True. + fig_size (tuple): Image show figure size. Defaults to (15, 10). + wait_time (int): How many seconds to display the image. Defaults to 0. + out_file (Optional[str]): The filename to write the image. + Defaults to None. + + Returns: + np.ndarray: The image with extra infomations. + """ + img = mmcv.imread(img).astype(np.uint8) + + x, y = 3, row_width // 2 + text_color = color_val_matplotlib(text_color) + + img = mmcv.bgr2rgb(img) + width, height = img.shape[1], img.shape[0] + img = np.ascontiguousarray(img) + + # A proper dpi for image save with default font size. + fig = plt.figure(win_name, frameon=False, figsize=fig_size, dpi=36) + plt.title(win_name) + canvas = fig.canvas + dpi = fig.get_dpi() + # add a small EPS to avoid precision lost due to matplotlib's truncation + # (https://github.com/matplotlib/matplotlib/issues/15363) + fig.set_size_inches((width + EPS) / dpi, (height + EPS) / dpi) + + # remove white edges by set subplot margin + plt.subplots_adjust(left=0, right=1, bottom=0, top=1) + ax = plt.gca() + ax.axis('off') + + for k, v in infos.items(): + if isinstance(v, float): + v = f'{v:.2f}' + label_text = f'{k}: {v}' + ax.text( + x, + y, + f'{label_text}', + bbox={ + 'facecolor': 'black', + 'alpha': 0.7, + 'pad': 0.2, + 'edgecolor': 'none', + 'boxstyle': 'round' + }, + color=text_color, + fontsize=font_size, + family='monospace', + verticalalignment='top', + horizontalalignment='left') + y += row_width + + plt.imshow(img) + stream, _ = canvas.print_to_buffer() + buffer = np.frombuffer(stream, dtype='uint8') + img_rgba = buffer.reshape(height, width, 4) + rgb, _ = np.split(img_rgba, [3], axis=2) + img = rgb.astype('uint8') + img = mmcv.rgb2bgr(img) + + if show: + # Matplotlib will adjust text size depends on window size and image + # aspect ratio. It's hard to get, so here we set an adaptive dpi + # according to screen height. 20 here is an empirical parameter. + fig_manager = plt.get_current_fig_manager() + if hasattr(fig_manager, 'window'): + # Figure manager doesn't have window if no screen. + screen_dpi = fig_manager.window.winfo_screenheight() // 20 + fig.set_dpi(screen_dpi) + + # We do not use cv2 for display because in some cases, opencv will + # conflict with Qt, it will output a warning: Current thread + # is not the object's thread. You can refer to + # https://github.com/opencv/opencv-python/issues/46 for details + if wait_time == 0: + plt.show() + else: + plt.show(block=False) + plt.pause(wait_time) + if out_file is not None: + mmcv.imwrite(img, out_file) + + plt.close() + + return img diff --git a/mmcls/models/classifiers/base.py b/mmcls/models/classifiers/base.py index 416e8b6e..725c4340 100644 --- a/mmcls/models/classifiers/base.py +++ b/mmcls/models/classifiers/base.py @@ -3,13 +3,13 @@ import warnings from abc import ABCMeta, abstractmethod from collections import OrderedDict -import cv2 import mmcv import torch import torch.distributed as dist -from mmcv import color_val from mmcv.runner import BaseModule +from mmcls.core.visualization import imshow_infos + # TODO import `auto_fp16` from mmcv and delete them from mmcls try: from mmcv.runner import auto_fp16 @@ -169,10 +169,11 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta): def show_result(self, img, result, - text_color='green', + text_color='white', font_scale=0.5, row_width=20, show=False, + fig_size=(15, 10), win_name='', wait_time=0, out_file=None): @@ -186,39 +187,29 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta): row_width (int): width between each row of results on the image. show (bool): Whether to show the image. Default: False. + fig_size (tuple): Image show figure size. Defaults to (15, 10). win_name (str): The window name. - wait_time (int): Value of waitKey param. - Default: 0. + wait_time (int): How many seconds to display the image. + Defaults to 0. out_file (str or None): The filename to write the image. Default: None. Returns: - img (ndarray): Only if not `show` or `out_file` + img (ndarray): Image with overlayed results. """ img = mmcv.imread(img) img = img.copy() - # write results on left-top of the image - x, y = 0, row_width - text_color = color_val(text_color) - for k, v in result.items(): - if isinstance(v, float): - v = f'{v:.2f}' - label_text = f'{k}: {v}' - cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX, - font_scale, text_color) - y += row_width + img = imshow_infos( + img, + result, + text_color=text_color, + font_size=int(font_scale * 50), + row_width=row_width, + win_name=win_name, + show=show, + fig_size=fig_size, + wait_time=wait_time, + out_file=out_file) - # if out_file specified, do not show image in window - if out_file is not None: - show = False - - if show: - mmcv.imshow(img, win_name, wait_time) - if out_file is not None: - mmcv.imwrite(img, out_file) - - if not (show or out_file): - warnings.warn('show==False and out_file is not specified, only ' - 'result image will be returned') - return img + return img diff --git a/setup.cfg b/setup.cfg index 262a3f11..27a9c11a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,6 +14,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = pkg_resources,setuptools known_first_party = mmcls -known_third_party = PIL,cv2,matplotlib,mmcv,mmdet,numpy,onnxruntime,packaging,pytest,seaborn,torch,torchvision,ts +known_third_party = PIL,matplotlib,mmcv,mmdet,numpy,onnxruntime,packaging,pytest,seaborn,torch,torchvision,ts no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/tests/test_models/test_classifiers.py b/tests/test_models/test_classifiers.py index 410b2698..7f6e24ca 100644 --- a/tests/test_models/test_classifiers.py +++ b/tests/test_models/test_classifiers.py @@ -2,7 +2,6 @@ import os.path as osp import tempfile from copy import deepcopy -from unittest.mock import patch import numpy as np import pytest @@ -85,16 +84,6 @@ def test_image_classifier(): model.show_result(img, result, out_file=out_file) assert osp.exists(out_file) - def save_show(_, *args): - out_path = osp.join(tmpdir, '_'.join([str(arg) for arg in args])) - with open(out_path, 'w') as f: - f.write('test') - - with patch('mmcv.imshow', save_show): - model.show_result( - img, result, show=True, win_name='img', wait_time=5) - assert osp.exists(osp.join(tmpdir, 'img_5')) - def test_image_classifier_with_mixup(): # Test mixup in ImageClassifier diff --git a/tests/test_utils/test_visualization.py b/tests/test_utils/test_visualization.py new file mode 100644 index 00000000..b30c0089 --- /dev/null +++ b/tests/test_utils/test_visualization.py @@ -0,0 +1,90 @@ +# Copyright (c) Open-MMLab. All rights reserved. +import os +import os.path as osp +import shutil +import tempfile +from unittest.mock import Mock, patch + +import mmcv +import numpy as np +import pytest + +from mmcls.core import visualization as vis + + +def test_color(): + assert vis.color_val_matplotlib(mmcv.Color.blue) == (0., 0., 1.) + assert vis.color_val_matplotlib('green') == (0., 1., 0.) + assert vis.color_val_matplotlib((1, 2, 3)) == (3 / 255, 2 / 255, 1 / 255) + assert vis.color_val_matplotlib(100) == (100 / 255, 100 / 255, 100 / 255) + assert vis.color_val_matplotlib(np.zeros(3, dtype=int)) == (0., 0., 0.) + # forbid white color + with pytest.raises(TypeError): + vis.color_val_matplotlib([255, 255, 255]) + # forbid float + with pytest.raises(TypeError): + vis.color_val_matplotlib(1.0) + # overflowed + with pytest.raises(AssertionError): + vis.color_val_matplotlib((0, 0, 500)) + + +def test_imshow_infos(): + tmp_dir = osp.join(tempfile.gettempdir(), 'infos_image') + tmp_filename = osp.join(tmp_dir, 'image.jpg') + + image = np.ones((10, 10, 3), np.uint8) + result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98} + out_image = vis.imshow_infos( + image, result, out_file=tmp_filename, show=False) + assert osp.isfile(tmp_filename) + assert image.shape == out_image.shape + assert not np.allclose(image, out_image) + os.remove(tmp_filename) + + # test grayscale images + image = np.ones((10, 10), np.uint8) + result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98} + out_image = vis.imshow_infos( + image, result, out_file=tmp_filename, show=False) + assert osp.isfile(tmp_filename) + assert image.shape == out_image.shape[:2] + os.remove(tmp_filename) + + # test show=True + image = np.ones((10, 10, 3), np.uint8) + result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98} + + def save_args(*args, **kwargs): + args_list = ['args'] + args_list += [ + str(arg) for arg in args if isinstance(arg, (str, bool, int)) + ] + args_list += [ + f'{k}-{v}' for k, v in kwargs.items() + if isinstance(v, (str, bool, int)) + ] + out_path = osp.join(tmp_dir, '_'.join(args_list)) + with open(out_path, 'w') as f: + f.write('test') + + with patch('matplotlib.pyplot.show', save_args), \ + patch('matplotlib.pyplot.pause', save_args): + vis.imshow_infos(image, result, show=True, wait_time=5) + assert osp.exists(osp.join(tmp_dir, 'args_block-False')) + assert osp.exists(osp.join(tmp_dir, 'args_5')) + + vis.imshow_infos(image, result, show=True, wait_time=0) + assert osp.exists(osp.join(tmp_dir, 'args')) + + # test adaptive dpi + def mock_fig_manager(): + fig_manager = Mock() + fig_manager.window.winfo_screenheight = Mock(return_value=1440) + return fig_manager + + with patch('matplotlib.pyplot.get_current_fig_manager', + mock_fig_manager), patch('matplotlib.pyplot.show'): + vis.imshow_infos(image, result, show=True) + + shutil.rmtree(tmp_dir)