[Enhance] Better result visualization (#419)
* Imporve result visualization to support wait time and change the backend to matplotlib. * Add unit test for visualization * Add adaptive dpi function * Rename `imshow_cls_result` to `imshow_infos`. * Support str in `imshow_infos` * Improve docstring.pull/426/head
parent
192b79eea0
commit
5383787512
|
@ -1,7 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -90,7 +89,7 @@ def inference_model(model, img):
|
|||
return result
|
||||
|
||||
|
||||
def show_result_pyplot(model, img, result, fig_size=(15, 10)):
|
||||
def show_result_pyplot(model, img, result, fig_size=(15, 10), wait_time=0):
|
||||
"""Visualize the classification results on the image.
|
||||
|
||||
Args:
|
||||
|
@ -98,10 +97,11 @@ def show_result_pyplot(model, img, result, fig_size=(15, 10)):
|
|||
img (str or np.ndarray): Image filename or loaded image.
|
||||
result (list): The classification result.
|
||||
fig_size (tuple): Figure size of the pyplot figure.
|
||||
Defaults to (15, 10).
|
||||
wait_time (int): How many seconds to display the image.
|
||||
Defaults to 0.
|
||||
"""
|
||||
if hasattr(model, 'module'):
|
||||
model = model.module
|
||||
img = model.show_result(img, result, show=False)
|
||||
plt.figure(figsize=fig_size)
|
||||
plt.imshow(mmcv.bgr2rgb(img))
|
||||
plt.show()
|
||||
model.show_result(
|
||||
img, result, show=True, fig_size=fig_size, wait_time=wait_time)
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
from .image import color_val_matplotlib, imshow_infos
|
||||
|
||||
__all__ = ['imshow_infos', 'color_val_matplotlib']
|
|
@ -0,0 +1,130 @@
|
|||
import matplotlib.pyplot as plt
|
||||
import mmcv
|
||||
import numpy as np
|
||||
|
||||
# A small value
|
||||
EPS = 1e-2
|
||||
|
||||
|
||||
def color_val_matplotlib(color):
|
||||
"""Convert various input in BGR order to normalized RGB matplotlib color
|
||||
tuples,
|
||||
|
||||
Args:
|
||||
color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Color inputs
|
||||
|
||||
Returns:
|
||||
tuple[float]: A tuple of 3 normalized floats indicating RGB channels.
|
||||
"""
|
||||
color = mmcv.color_val(color)
|
||||
color = [color / 255 for color in color[::-1]]
|
||||
return tuple(color)
|
||||
|
||||
|
||||
def imshow_infos(img,
|
||||
infos,
|
||||
text_color='white',
|
||||
font_size=26,
|
||||
row_width=20,
|
||||
win_name='',
|
||||
show=True,
|
||||
fig_size=(15, 10),
|
||||
wait_time=0,
|
||||
out_file=None):
|
||||
"""Show image with extra infomation.
|
||||
|
||||
Args:
|
||||
img (str | ndarray): The image to be displayed.
|
||||
infos (dict): Extra infos to display in the image.
|
||||
text_color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Extra infos
|
||||
display color. Defaults to 'white'.
|
||||
font_size (int): Extra infos display font size. Defaults to 26.
|
||||
row_width (int): width between each row of results on the image.
|
||||
win_name (str): The image title. Defaults to ''
|
||||
show (bool): Whether to show the image. Defaults to True.
|
||||
fig_size (tuple): Image show figure size. Defaults to (15, 10).
|
||||
wait_time (int): How many seconds to display the image. Defaults to 0.
|
||||
out_file (Optional[str]): The filename to write the image.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The image with extra infomations.
|
||||
"""
|
||||
img = mmcv.imread(img).astype(np.uint8)
|
||||
|
||||
x, y = 3, row_width // 2
|
||||
text_color = color_val_matplotlib(text_color)
|
||||
|
||||
img = mmcv.bgr2rgb(img)
|
||||
width, height = img.shape[1], img.shape[0]
|
||||
img = np.ascontiguousarray(img)
|
||||
|
||||
# A proper dpi for image save with default font size.
|
||||
fig = plt.figure(win_name, frameon=False, figsize=fig_size, dpi=36)
|
||||
plt.title(win_name)
|
||||
canvas = fig.canvas
|
||||
dpi = fig.get_dpi()
|
||||
# add a small EPS to avoid precision lost due to matplotlib's truncation
|
||||
# (https://github.com/matplotlib/matplotlib/issues/15363)
|
||||
fig.set_size_inches((width + EPS) / dpi, (height + EPS) / dpi)
|
||||
|
||||
# remove white edges by set subplot margin
|
||||
plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
|
||||
ax = plt.gca()
|
||||
ax.axis('off')
|
||||
|
||||
for k, v in infos.items():
|
||||
if isinstance(v, float):
|
||||
v = f'{v:.2f}'
|
||||
label_text = f'{k}: {v}'
|
||||
ax.text(
|
||||
x,
|
||||
y,
|
||||
f'{label_text}',
|
||||
bbox={
|
||||
'facecolor': 'black',
|
||||
'alpha': 0.7,
|
||||
'pad': 0.2,
|
||||
'edgecolor': 'none',
|
||||
'boxstyle': 'round'
|
||||
},
|
||||
color=text_color,
|
||||
fontsize=font_size,
|
||||
family='monospace',
|
||||
verticalalignment='top',
|
||||
horizontalalignment='left')
|
||||
y += row_width
|
||||
|
||||
plt.imshow(img)
|
||||
stream, _ = canvas.print_to_buffer()
|
||||
buffer = np.frombuffer(stream, dtype='uint8')
|
||||
img_rgba = buffer.reshape(height, width, 4)
|
||||
rgb, _ = np.split(img_rgba, [3], axis=2)
|
||||
img = rgb.astype('uint8')
|
||||
img = mmcv.rgb2bgr(img)
|
||||
|
||||
if show:
|
||||
# Matplotlib will adjust text size depends on window size and image
|
||||
# aspect ratio. It's hard to get, so here we set an adaptive dpi
|
||||
# according to screen height. 20 here is an empirical parameter.
|
||||
fig_manager = plt.get_current_fig_manager()
|
||||
if hasattr(fig_manager, 'window'):
|
||||
# Figure manager doesn't have window if no screen.
|
||||
screen_dpi = fig_manager.window.winfo_screenheight() // 20
|
||||
fig.set_dpi(screen_dpi)
|
||||
|
||||
# We do not use cv2 for display because in some cases, opencv will
|
||||
# conflict with Qt, it will output a warning: Current thread
|
||||
# is not the object's thread. You can refer to
|
||||
# https://github.com/opencv/opencv-python/issues/46 for details
|
||||
if wait_time == 0:
|
||||
plt.show()
|
||||
else:
|
||||
plt.show(block=False)
|
||||
plt.pause(wait_time)
|
||||
if out_file is not None:
|
||||
mmcv.imwrite(img, out_file)
|
||||
|
||||
plt.close()
|
||||
|
||||
return img
|
|
@ -3,13 +3,13 @@ import warnings
|
|||
from abc import ABCMeta, abstractmethod
|
||||
from collections import OrderedDict
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from mmcv import color_val
|
||||
from mmcv.runner import BaseModule
|
||||
|
||||
from mmcls.core.visualization import imshow_infos
|
||||
|
||||
# TODO import `auto_fp16` from mmcv and delete them from mmcls
|
||||
try:
|
||||
from mmcv.runner import auto_fp16
|
||||
|
@ -169,10 +169,11 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
|
|||
def show_result(self,
|
||||
img,
|
||||
result,
|
||||
text_color='green',
|
||||
text_color='white',
|
||||
font_scale=0.5,
|
||||
row_width=20,
|
||||
show=False,
|
||||
fig_size=(15, 10),
|
||||
win_name='',
|
||||
wait_time=0,
|
||||
out_file=None):
|
||||
|
@ -186,39 +187,29 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
|
|||
row_width (int): width between each row of results on the image.
|
||||
show (bool): Whether to show the image.
|
||||
Default: False.
|
||||
fig_size (tuple): Image show figure size. Defaults to (15, 10).
|
||||
win_name (str): The window name.
|
||||
wait_time (int): Value of waitKey param.
|
||||
Default: 0.
|
||||
wait_time (int): How many seconds to display the image.
|
||||
Defaults to 0.
|
||||
out_file (str or None): The filename to write the image.
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
img (ndarray): Only if not `show` or `out_file`
|
||||
img (ndarray): Image with overlayed results.
|
||||
"""
|
||||
img = mmcv.imread(img)
|
||||
img = img.copy()
|
||||
|
||||
# write results on left-top of the image
|
||||
x, y = 0, row_width
|
||||
text_color = color_val(text_color)
|
||||
for k, v in result.items():
|
||||
if isinstance(v, float):
|
||||
v = f'{v:.2f}'
|
||||
label_text = f'{k}: {v}'
|
||||
cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
|
||||
font_scale, text_color)
|
||||
y += row_width
|
||||
img = imshow_infos(
|
||||
img,
|
||||
result,
|
||||
text_color=text_color,
|
||||
font_size=int(font_scale * 50),
|
||||
row_width=row_width,
|
||||
win_name=win_name,
|
||||
show=show,
|
||||
fig_size=fig_size,
|
||||
wait_time=wait_time,
|
||||
out_file=out_file)
|
||||
|
||||
# if out_file specified, do not show image in window
|
||||
if out_file is not None:
|
||||
show = False
|
||||
|
||||
if show:
|
||||
mmcv.imshow(img, win_name, wait_time)
|
||||
if out_file is not None:
|
||||
mmcv.imwrite(img, out_file)
|
||||
|
||||
if not (show or out_file):
|
||||
warnings.warn('show==False and out_file is not specified, only '
|
||||
'result image will be returned')
|
||||
return img
|
||||
return img
|
||||
|
|
|
@ -14,6 +14,6 @@ line_length = 79
|
|||
multi_line_output = 0
|
||||
known_standard_library = pkg_resources,setuptools
|
||||
known_first_party = mmcls
|
||||
known_third_party = PIL,cv2,matplotlib,mmcv,mmdet,numpy,onnxruntime,packaging,pytest,seaborn,torch,torchvision,ts
|
||||
known_third_party = PIL,matplotlib,mmcv,mmdet,numpy,onnxruntime,packaging,pytest,seaborn,torch,torchvision,ts
|
||||
no_lines_before = STDLIB,LOCALFOLDER
|
||||
default_section = THIRDPARTY
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
import os.path as osp
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
@ -85,16 +84,6 @@ def test_image_classifier():
|
|||
model.show_result(img, result, out_file=out_file)
|
||||
assert osp.exists(out_file)
|
||||
|
||||
def save_show(_, *args):
|
||||
out_path = osp.join(tmpdir, '_'.join([str(arg) for arg in args]))
|
||||
with open(out_path, 'w') as f:
|
||||
f.write('test')
|
||||
|
||||
with patch('mmcv.imshow', save_show):
|
||||
model.show_result(
|
||||
img, result, show=True, win_name='img', wait_time=5)
|
||||
assert osp.exists(osp.join(tmpdir, 'img_5'))
|
||||
|
||||
|
||||
def test_image_classifier_with_mixup():
|
||||
# Test mixup in ImageClassifier
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright (c) Open-MMLab. All rights reserved.
|
||||
import os
|
||||
import os.path as osp
|
||||
import shutil
|
||||
import tempfile
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mmcls.core import visualization as vis
|
||||
|
||||
|
||||
def test_color():
|
||||
assert vis.color_val_matplotlib(mmcv.Color.blue) == (0., 0., 1.)
|
||||
assert vis.color_val_matplotlib('green') == (0., 1., 0.)
|
||||
assert vis.color_val_matplotlib((1, 2, 3)) == (3 / 255, 2 / 255, 1 / 255)
|
||||
assert vis.color_val_matplotlib(100) == (100 / 255, 100 / 255, 100 / 255)
|
||||
assert vis.color_val_matplotlib(np.zeros(3, dtype=int)) == (0., 0., 0.)
|
||||
# forbid white color
|
||||
with pytest.raises(TypeError):
|
||||
vis.color_val_matplotlib([255, 255, 255])
|
||||
# forbid float
|
||||
with pytest.raises(TypeError):
|
||||
vis.color_val_matplotlib(1.0)
|
||||
# overflowed
|
||||
with pytest.raises(AssertionError):
|
||||
vis.color_val_matplotlib((0, 0, 500))
|
||||
|
||||
|
||||
def test_imshow_infos():
|
||||
tmp_dir = osp.join(tempfile.gettempdir(), 'infos_image')
|
||||
tmp_filename = osp.join(tmp_dir, 'image.jpg')
|
||||
|
||||
image = np.ones((10, 10, 3), np.uint8)
|
||||
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
|
||||
out_image = vis.imshow_infos(
|
||||
image, result, out_file=tmp_filename, show=False)
|
||||
assert osp.isfile(tmp_filename)
|
||||
assert image.shape == out_image.shape
|
||||
assert not np.allclose(image, out_image)
|
||||
os.remove(tmp_filename)
|
||||
|
||||
# test grayscale images
|
||||
image = np.ones((10, 10), np.uint8)
|
||||
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
|
||||
out_image = vis.imshow_infos(
|
||||
image, result, out_file=tmp_filename, show=False)
|
||||
assert osp.isfile(tmp_filename)
|
||||
assert image.shape == out_image.shape[:2]
|
||||
os.remove(tmp_filename)
|
||||
|
||||
# test show=True
|
||||
image = np.ones((10, 10, 3), np.uint8)
|
||||
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
|
||||
|
||||
def save_args(*args, **kwargs):
|
||||
args_list = ['args']
|
||||
args_list += [
|
||||
str(arg) for arg in args if isinstance(arg, (str, bool, int))
|
||||
]
|
||||
args_list += [
|
||||
f'{k}-{v}' for k, v in kwargs.items()
|
||||
if isinstance(v, (str, bool, int))
|
||||
]
|
||||
out_path = osp.join(tmp_dir, '_'.join(args_list))
|
||||
with open(out_path, 'w') as f:
|
||||
f.write('test')
|
||||
|
||||
with patch('matplotlib.pyplot.show', save_args), \
|
||||
patch('matplotlib.pyplot.pause', save_args):
|
||||
vis.imshow_infos(image, result, show=True, wait_time=5)
|
||||
assert osp.exists(osp.join(tmp_dir, 'args_block-False'))
|
||||
assert osp.exists(osp.join(tmp_dir, 'args_5'))
|
||||
|
||||
vis.imshow_infos(image, result, show=True, wait_time=0)
|
||||
assert osp.exists(osp.join(tmp_dir, 'args'))
|
||||
|
||||
# test adaptive dpi
|
||||
def mock_fig_manager():
|
||||
fig_manager = Mock()
|
||||
fig_manager.window.winfo_screenheight = Mock(return_value=1440)
|
||||
return fig_manager
|
||||
|
||||
with patch('matplotlib.pyplot.get_current_fig_manager',
|
||||
mock_fig_manager), patch('matplotlib.pyplot.show'):
|
||||
vis.imshow_infos(image, result, show=True)
|
||||
|
||||
shutil.rmtree(tmp_dir)
|
Loading…
Reference in New Issue