[Enhance] Better result visualization (#419)

* Imporve result visualization to support wait time and change the backend
to matplotlib.

* Add unit test for visualization

* Add adaptive dpi function

* Rename `imshow_cls_result` to `imshow_infos`.

* Support str in `imshow_infos`

* Improve docstring.
pull/426/head
Ma Zerun 2021-08-31 10:50:28 +08:00 committed by GitHub
parent 192b79eea0
commit 5383787512
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 250 additions and 47 deletions

View File

@ -1,7 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import matplotlib.pyplot as plt
import mmcv
import numpy as np
import torch
@ -90,7 +89,7 @@ def inference_model(model, img):
return result
def show_result_pyplot(model, img, result, fig_size=(15, 10)):
def show_result_pyplot(model, img, result, fig_size=(15, 10), wait_time=0):
"""Visualize the classification results on the image.
Args:
@ -98,10 +97,11 @@ def show_result_pyplot(model, img, result, fig_size=(15, 10)):
img (str or np.ndarray): Image filename or loaded image.
result (list): The classification result.
fig_size (tuple): Figure size of the pyplot figure.
Defaults to (15, 10).
wait_time (int): How many seconds to display the image.
Defaults to 0.
"""
if hasattr(model, 'module'):
model = model.module
img = model.show_result(img, result, show=False)
plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img))
plt.show()
model.show_result(
img, result, show=True, fig_size=fig_size, wait_time=wait_time)

View File

@ -0,0 +1,3 @@
from .image import color_val_matplotlib, imshow_infos
__all__ = ['imshow_infos', 'color_val_matplotlib']

View File

@ -0,0 +1,130 @@
import matplotlib.pyplot as plt
import mmcv
import numpy as np
# A small value
EPS = 1e-2
def color_val_matplotlib(color):
"""Convert various input in BGR order to normalized RGB matplotlib color
tuples,
Args:
color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Color inputs
Returns:
tuple[float]: A tuple of 3 normalized floats indicating RGB channels.
"""
color = mmcv.color_val(color)
color = [color / 255 for color in color[::-1]]
return tuple(color)
def imshow_infos(img,
infos,
text_color='white',
font_size=26,
row_width=20,
win_name='',
show=True,
fig_size=(15, 10),
wait_time=0,
out_file=None):
"""Show image with extra infomation.
Args:
img (str | ndarray): The image to be displayed.
infos (dict): Extra infos to display in the image.
text_color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Extra infos
display color. Defaults to 'white'.
font_size (int): Extra infos display font size. Defaults to 26.
row_width (int): width between each row of results on the image.
win_name (str): The image title. Defaults to ''
show (bool): Whether to show the image. Defaults to True.
fig_size (tuple): Image show figure size. Defaults to (15, 10).
wait_time (int): How many seconds to display the image. Defaults to 0.
out_file (Optional[str]): The filename to write the image.
Defaults to None.
Returns:
np.ndarray: The image with extra infomations.
"""
img = mmcv.imread(img).astype(np.uint8)
x, y = 3, row_width // 2
text_color = color_val_matplotlib(text_color)
img = mmcv.bgr2rgb(img)
width, height = img.shape[1], img.shape[0]
img = np.ascontiguousarray(img)
# A proper dpi for image save with default font size.
fig = plt.figure(win_name, frameon=False, figsize=fig_size, dpi=36)
plt.title(win_name)
canvas = fig.canvas
dpi = fig.get_dpi()
# add a small EPS to avoid precision lost due to matplotlib's truncation
# (https://github.com/matplotlib/matplotlib/issues/15363)
fig.set_size_inches((width + EPS) / dpi, (height + EPS) / dpi)
# remove white edges by set subplot margin
plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
ax = plt.gca()
ax.axis('off')
for k, v in infos.items():
if isinstance(v, float):
v = f'{v:.2f}'
label_text = f'{k}: {v}'
ax.text(
x,
y,
f'{label_text}',
bbox={
'facecolor': 'black',
'alpha': 0.7,
'pad': 0.2,
'edgecolor': 'none',
'boxstyle': 'round'
},
color=text_color,
fontsize=font_size,
family='monospace',
verticalalignment='top',
horizontalalignment='left')
y += row_width
plt.imshow(img)
stream, _ = canvas.print_to_buffer()
buffer = np.frombuffer(stream, dtype='uint8')
img_rgba = buffer.reshape(height, width, 4)
rgb, _ = np.split(img_rgba, [3], axis=2)
img = rgb.astype('uint8')
img = mmcv.rgb2bgr(img)
if show:
# Matplotlib will adjust text size depends on window size and image
# aspect ratio. It's hard to get, so here we set an adaptive dpi
# according to screen height. 20 here is an empirical parameter.
fig_manager = plt.get_current_fig_manager()
if hasattr(fig_manager, 'window'):
# Figure manager doesn't have window if no screen.
screen_dpi = fig_manager.window.winfo_screenheight() // 20
fig.set_dpi(screen_dpi)
# We do not use cv2 for display because in some cases, opencv will
# conflict with Qt, it will output a warning: Current thread
# is not the object's thread. You can refer to
# https://github.com/opencv/opencv-python/issues/46 for details
if wait_time == 0:
plt.show()
else:
plt.show(block=False)
plt.pause(wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)
plt.close()
return img

View File

@ -3,13 +3,13 @@ import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
import cv2
import mmcv
import torch
import torch.distributed as dist
from mmcv import color_val
from mmcv.runner import BaseModule
from mmcls.core.visualization import imshow_infos
# TODO import `auto_fp16` from mmcv and delete them from mmcls
try:
from mmcv.runner import auto_fp16
@ -169,10 +169,11 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
def show_result(self,
img,
result,
text_color='green',
text_color='white',
font_scale=0.5,
row_width=20,
show=False,
fig_size=(15, 10),
win_name='',
wait_time=0,
out_file=None):
@ -186,39 +187,29 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
row_width (int): width between each row of results on the image.
show (bool): Whether to show the image.
Default: False.
fig_size (tuple): Image show figure size. Defaults to (15, 10).
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
wait_time (int): How many seconds to display the image.
Defaults to 0.
out_file (str or None): The filename to write the image.
Default: None.
Returns:
img (ndarray): Only if not `show` or `out_file`
img (ndarray): Image with overlayed results.
"""
img = mmcv.imread(img)
img = img.copy()
# write results on left-top of the image
x, y = 0, row_width
text_color = color_val(text_color)
for k, v in result.items():
if isinstance(v, float):
v = f'{v:.2f}'
label_text = f'{k}: {v}'
cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
font_scale, text_color)
y += row_width
img = imshow_infos(
img,
result,
text_color=text_color,
font_size=int(font_scale * 50),
row_width=row_width,
win_name=win_name,
show=show,
fig_size=fig_size,
wait_time=wait_time,
out_file=out_file)
# if out_file specified, do not show image in window
if out_file is not None:
show = False
if show:
mmcv.imshow(img, win_name, wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)
if not (show or out_file):
warnings.warn('show==False and out_file is not specified, only '
'result image will be returned')
return img
return img

View File

@ -14,6 +14,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = mmcls
known_third_party = PIL,cv2,matplotlib,mmcv,mmdet,numpy,onnxruntime,packaging,pytest,seaborn,torch,torchvision,ts
known_third_party = PIL,matplotlib,mmcv,mmdet,numpy,onnxruntime,packaging,pytest,seaborn,torch,torchvision,ts
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

View File

@ -2,7 +2,6 @@
import os.path as osp
import tempfile
from copy import deepcopy
from unittest.mock import patch
import numpy as np
import pytest
@ -85,16 +84,6 @@ def test_image_classifier():
model.show_result(img, result, out_file=out_file)
assert osp.exists(out_file)
def save_show(_, *args):
out_path = osp.join(tmpdir, '_'.join([str(arg) for arg in args]))
with open(out_path, 'w') as f:
f.write('test')
with patch('mmcv.imshow', save_show):
model.show_result(
img, result, show=True, win_name='img', wait_time=5)
assert osp.exists(osp.join(tmpdir, 'img_5'))
def test_image_classifier_with_mixup():
# Test mixup in ImageClassifier

View File

@ -0,0 +1,90 @@
# Copyright (c) Open-MMLab. All rights reserved.
import os
import os.path as osp
import shutil
import tempfile
from unittest.mock import Mock, patch
import mmcv
import numpy as np
import pytest
from mmcls.core import visualization as vis
def test_color():
assert vis.color_val_matplotlib(mmcv.Color.blue) == (0., 0., 1.)
assert vis.color_val_matplotlib('green') == (0., 1., 0.)
assert vis.color_val_matplotlib((1, 2, 3)) == (3 / 255, 2 / 255, 1 / 255)
assert vis.color_val_matplotlib(100) == (100 / 255, 100 / 255, 100 / 255)
assert vis.color_val_matplotlib(np.zeros(3, dtype=int)) == (0., 0., 0.)
# forbid white color
with pytest.raises(TypeError):
vis.color_val_matplotlib([255, 255, 255])
# forbid float
with pytest.raises(TypeError):
vis.color_val_matplotlib(1.0)
# overflowed
with pytest.raises(AssertionError):
vis.color_val_matplotlib((0, 0, 500))
def test_imshow_infos():
tmp_dir = osp.join(tempfile.gettempdir(), 'infos_image')
tmp_filename = osp.join(tmp_dir, 'image.jpg')
image = np.ones((10, 10, 3), np.uint8)
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
out_image = vis.imshow_infos(
image, result, out_file=tmp_filename, show=False)
assert osp.isfile(tmp_filename)
assert image.shape == out_image.shape
assert not np.allclose(image, out_image)
os.remove(tmp_filename)
# test grayscale images
image = np.ones((10, 10), np.uint8)
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
out_image = vis.imshow_infos(
image, result, out_file=tmp_filename, show=False)
assert osp.isfile(tmp_filename)
assert image.shape == out_image.shape[:2]
os.remove(tmp_filename)
# test show=True
image = np.ones((10, 10, 3), np.uint8)
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
def save_args(*args, **kwargs):
args_list = ['args']
args_list += [
str(arg) for arg in args if isinstance(arg, (str, bool, int))
]
args_list += [
f'{k}-{v}' for k, v in kwargs.items()
if isinstance(v, (str, bool, int))
]
out_path = osp.join(tmp_dir, '_'.join(args_list))
with open(out_path, 'w') as f:
f.write('test')
with patch('matplotlib.pyplot.show', save_args), \
patch('matplotlib.pyplot.pause', save_args):
vis.imshow_infos(image, result, show=True, wait_time=5)
assert osp.exists(osp.join(tmp_dir, 'args_block-False'))
assert osp.exists(osp.join(tmp_dir, 'args_5'))
vis.imshow_infos(image, result, show=True, wait_time=0)
assert osp.exists(osp.join(tmp_dir, 'args'))
# test adaptive dpi
def mock_fig_manager():
fig_manager = Mock()
fig_manager.window.winfo_screenheight = Mock(return_value=1440)
return fig_manager
with patch('matplotlib.pyplot.get_current_fig_manager',
mock_fig_manager), patch('matplotlib.pyplot.show'):
vis.imshow_infos(image, result, show=True)
shutil.rmtree(tmp_dir)