[Enhance] Better result visualization (#419)
* Imporve result visualization to support wait time and change the backend to matplotlib. * Add unit test for visualization * Add adaptive dpi function * Rename `imshow_cls_result` to `imshow_infos`. * Support str in `imshow_infos` * Improve docstring.pull/426/head
parent
192b79eea0
commit
5383787512
|
@ -1,7 +1,6 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -90,7 +89,7 @@ def inference_model(model, img):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def show_result_pyplot(model, img, result, fig_size=(15, 10)):
|
def show_result_pyplot(model, img, result, fig_size=(15, 10), wait_time=0):
|
||||||
"""Visualize the classification results on the image.
|
"""Visualize the classification results on the image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -98,10 +97,11 @@ def show_result_pyplot(model, img, result, fig_size=(15, 10)):
|
||||||
img (str or np.ndarray): Image filename or loaded image.
|
img (str or np.ndarray): Image filename or loaded image.
|
||||||
result (list): The classification result.
|
result (list): The classification result.
|
||||||
fig_size (tuple): Figure size of the pyplot figure.
|
fig_size (tuple): Figure size of the pyplot figure.
|
||||||
|
Defaults to (15, 10).
|
||||||
|
wait_time (int): How many seconds to display the image.
|
||||||
|
Defaults to 0.
|
||||||
"""
|
"""
|
||||||
if hasattr(model, 'module'):
|
if hasattr(model, 'module'):
|
||||||
model = model.module
|
model = model.module
|
||||||
img = model.show_result(img, result, show=False)
|
model.show_result(
|
||||||
plt.figure(figsize=fig_size)
|
img, result, show=True, fig_size=fig_size, wait_time=wait_time)
|
||||||
plt.imshow(mmcv.bgr2rgb(img))
|
|
||||||
plt.show()
|
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
from .image import color_val_matplotlib, imshow_infos
|
||||||
|
|
||||||
|
__all__ = ['imshow_infos', 'color_val_matplotlib']
|
|
@ -0,0 +1,130 @@
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import mmcv
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# A small value
|
||||||
|
EPS = 1e-2
|
||||||
|
|
||||||
|
|
||||||
|
def color_val_matplotlib(color):
|
||||||
|
"""Convert various input in BGR order to normalized RGB matplotlib color
|
||||||
|
tuples,
|
||||||
|
|
||||||
|
Args:
|
||||||
|
color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Color inputs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[float]: A tuple of 3 normalized floats indicating RGB channels.
|
||||||
|
"""
|
||||||
|
color = mmcv.color_val(color)
|
||||||
|
color = [color / 255 for color in color[::-1]]
|
||||||
|
return tuple(color)
|
||||||
|
|
||||||
|
|
||||||
|
def imshow_infos(img,
|
||||||
|
infos,
|
||||||
|
text_color='white',
|
||||||
|
font_size=26,
|
||||||
|
row_width=20,
|
||||||
|
win_name='',
|
||||||
|
show=True,
|
||||||
|
fig_size=(15, 10),
|
||||||
|
wait_time=0,
|
||||||
|
out_file=None):
|
||||||
|
"""Show image with extra infomation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (str | ndarray): The image to be displayed.
|
||||||
|
infos (dict): Extra infos to display in the image.
|
||||||
|
text_color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Extra infos
|
||||||
|
display color. Defaults to 'white'.
|
||||||
|
font_size (int): Extra infos display font size. Defaults to 26.
|
||||||
|
row_width (int): width between each row of results on the image.
|
||||||
|
win_name (str): The image title. Defaults to ''
|
||||||
|
show (bool): Whether to show the image. Defaults to True.
|
||||||
|
fig_size (tuple): Image show figure size. Defaults to (15, 10).
|
||||||
|
wait_time (int): How many seconds to display the image. Defaults to 0.
|
||||||
|
out_file (Optional[str]): The filename to write the image.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: The image with extra infomations.
|
||||||
|
"""
|
||||||
|
img = mmcv.imread(img).astype(np.uint8)
|
||||||
|
|
||||||
|
x, y = 3, row_width // 2
|
||||||
|
text_color = color_val_matplotlib(text_color)
|
||||||
|
|
||||||
|
img = mmcv.bgr2rgb(img)
|
||||||
|
width, height = img.shape[1], img.shape[0]
|
||||||
|
img = np.ascontiguousarray(img)
|
||||||
|
|
||||||
|
# A proper dpi for image save with default font size.
|
||||||
|
fig = plt.figure(win_name, frameon=False, figsize=fig_size, dpi=36)
|
||||||
|
plt.title(win_name)
|
||||||
|
canvas = fig.canvas
|
||||||
|
dpi = fig.get_dpi()
|
||||||
|
# add a small EPS to avoid precision lost due to matplotlib's truncation
|
||||||
|
# (https://github.com/matplotlib/matplotlib/issues/15363)
|
||||||
|
fig.set_size_inches((width + EPS) / dpi, (height + EPS) / dpi)
|
||||||
|
|
||||||
|
# remove white edges by set subplot margin
|
||||||
|
plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
|
||||||
|
ax = plt.gca()
|
||||||
|
ax.axis('off')
|
||||||
|
|
||||||
|
for k, v in infos.items():
|
||||||
|
if isinstance(v, float):
|
||||||
|
v = f'{v:.2f}'
|
||||||
|
label_text = f'{k}: {v}'
|
||||||
|
ax.text(
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
f'{label_text}',
|
||||||
|
bbox={
|
||||||
|
'facecolor': 'black',
|
||||||
|
'alpha': 0.7,
|
||||||
|
'pad': 0.2,
|
||||||
|
'edgecolor': 'none',
|
||||||
|
'boxstyle': 'round'
|
||||||
|
},
|
||||||
|
color=text_color,
|
||||||
|
fontsize=font_size,
|
||||||
|
family='monospace',
|
||||||
|
verticalalignment='top',
|
||||||
|
horizontalalignment='left')
|
||||||
|
y += row_width
|
||||||
|
|
||||||
|
plt.imshow(img)
|
||||||
|
stream, _ = canvas.print_to_buffer()
|
||||||
|
buffer = np.frombuffer(stream, dtype='uint8')
|
||||||
|
img_rgba = buffer.reshape(height, width, 4)
|
||||||
|
rgb, _ = np.split(img_rgba, [3], axis=2)
|
||||||
|
img = rgb.astype('uint8')
|
||||||
|
img = mmcv.rgb2bgr(img)
|
||||||
|
|
||||||
|
if show:
|
||||||
|
# Matplotlib will adjust text size depends on window size and image
|
||||||
|
# aspect ratio. It's hard to get, so here we set an adaptive dpi
|
||||||
|
# according to screen height. 20 here is an empirical parameter.
|
||||||
|
fig_manager = plt.get_current_fig_manager()
|
||||||
|
if hasattr(fig_manager, 'window'):
|
||||||
|
# Figure manager doesn't have window if no screen.
|
||||||
|
screen_dpi = fig_manager.window.winfo_screenheight() // 20
|
||||||
|
fig.set_dpi(screen_dpi)
|
||||||
|
|
||||||
|
# We do not use cv2 for display because in some cases, opencv will
|
||||||
|
# conflict with Qt, it will output a warning: Current thread
|
||||||
|
# is not the object's thread. You can refer to
|
||||||
|
# https://github.com/opencv/opencv-python/issues/46 for details
|
||||||
|
if wait_time == 0:
|
||||||
|
plt.show()
|
||||||
|
else:
|
||||||
|
plt.show(block=False)
|
||||||
|
plt.pause(wait_time)
|
||||||
|
if out_file is not None:
|
||||||
|
mmcv.imwrite(img, out_file)
|
||||||
|
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
return img
|
|
@ -3,13 +3,13 @@ import warnings
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
import cv2
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from mmcv import color_val
|
|
||||||
from mmcv.runner import BaseModule
|
from mmcv.runner import BaseModule
|
||||||
|
|
||||||
|
from mmcls.core.visualization import imshow_infos
|
||||||
|
|
||||||
# TODO import `auto_fp16` from mmcv and delete them from mmcls
|
# TODO import `auto_fp16` from mmcv and delete them from mmcls
|
||||||
try:
|
try:
|
||||||
from mmcv.runner import auto_fp16
|
from mmcv.runner import auto_fp16
|
||||||
|
@ -169,10 +169,11 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
|
||||||
def show_result(self,
|
def show_result(self,
|
||||||
img,
|
img,
|
||||||
result,
|
result,
|
||||||
text_color='green',
|
text_color='white',
|
||||||
font_scale=0.5,
|
font_scale=0.5,
|
||||||
row_width=20,
|
row_width=20,
|
||||||
show=False,
|
show=False,
|
||||||
|
fig_size=(15, 10),
|
||||||
win_name='',
|
win_name='',
|
||||||
wait_time=0,
|
wait_time=0,
|
||||||
out_file=None):
|
out_file=None):
|
||||||
|
@ -186,39 +187,29 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
|
||||||
row_width (int): width between each row of results on the image.
|
row_width (int): width between each row of results on the image.
|
||||||
show (bool): Whether to show the image.
|
show (bool): Whether to show the image.
|
||||||
Default: False.
|
Default: False.
|
||||||
|
fig_size (tuple): Image show figure size. Defaults to (15, 10).
|
||||||
win_name (str): The window name.
|
win_name (str): The window name.
|
||||||
wait_time (int): Value of waitKey param.
|
wait_time (int): How many seconds to display the image.
|
||||||
Default: 0.
|
Defaults to 0.
|
||||||
out_file (str or None): The filename to write the image.
|
out_file (str or None): The filename to write the image.
|
||||||
Default: None.
|
Default: None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
img (ndarray): Only if not `show` or `out_file`
|
img (ndarray): Image with overlayed results.
|
||||||
"""
|
"""
|
||||||
img = mmcv.imread(img)
|
img = mmcv.imread(img)
|
||||||
img = img.copy()
|
img = img.copy()
|
||||||
|
|
||||||
# write results on left-top of the image
|
img = imshow_infos(
|
||||||
x, y = 0, row_width
|
img,
|
||||||
text_color = color_val(text_color)
|
result,
|
||||||
for k, v in result.items():
|
text_color=text_color,
|
||||||
if isinstance(v, float):
|
font_size=int(font_scale * 50),
|
||||||
v = f'{v:.2f}'
|
row_width=row_width,
|
||||||
label_text = f'{k}: {v}'
|
win_name=win_name,
|
||||||
cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
|
show=show,
|
||||||
font_scale, text_color)
|
fig_size=fig_size,
|
||||||
y += row_width
|
wait_time=wait_time,
|
||||||
|
out_file=out_file)
|
||||||
|
|
||||||
# if out_file specified, do not show image in window
|
return img
|
||||||
if out_file is not None:
|
|
||||||
show = False
|
|
||||||
|
|
||||||
if show:
|
|
||||||
mmcv.imshow(img, win_name, wait_time)
|
|
||||||
if out_file is not None:
|
|
||||||
mmcv.imwrite(img, out_file)
|
|
||||||
|
|
||||||
if not (show or out_file):
|
|
||||||
warnings.warn('show==False and out_file is not specified, only '
|
|
||||||
'result image will be returned')
|
|
||||||
return img
|
|
||||||
|
|
|
@ -14,6 +14,6 @@ line_length = 79
|
||||||
multi_line_output = 0
|
multi_line_output = 0
|
||||||
known_standard_library = pkg_resources,setuptools
|
known_standard_library = pkg_resources,setuptools
|
||||||
known_first_party = mmcls
|
known_first_party = mmcls
|
||||||
known_third_party = PIL,cv2,matplotlib,mmcv,mmdet,numpy,onnxruntime,packaging,pytest,seaborn,torch,torchvision,ts
|
known_third_party = PIL,matplotlib,mmcv,mmdet,numpy,onnxruntime,packaging,pytest,seaborn,torch,torchvision,ts
|
||||||
no_lines_before = STDLIB,LOCALFOLDER
|
no_lines_before = STDLIB,LOCALFOLDER
|
||||||
default_section = THIRDPARTY
|
default_section = THIRDPARTY
|
||||||
|
|
|
@ -2,7 +2,6 @@
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import tempfile
|
import tempfile
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -85,16 +84,6 @@ def test_image_classifier():
|
||||||
model.show_result(img, result, out_file=out_file)
|
model.show_result(img, result, out_file=out_file)
|
||||||
assert osp.exists(out_file)
|
assert osp.exists(out_file)
|
||||||
|
|
||||||
def save_show(_, *args):
|
|
||||||
out_path = osp.join(tmpdir, '_'.join([str(arg) for arg in args]))
|
|
||||||
with open(out_path, 'w') as f:
|
|
||||||
f.write('test')
|
|
||||||
|
|
||||||
with patch('mmcv.imshow', save_show):
|
|
||||||
model.show_result(
|
|
||||||
img, result, show=True, win_name='img', wait_time=5)
|
|
||||||
assert osp.exists(osp.join(tmpdir, 'img_5'))
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_classifier_with_mixup():
|
def test_image_classifier_with_mixup():
|
||||||
# Test mixup in ImageClassifier
|
# Test mixup in ImageClassifier
|
||||||
|
|
|
@ -0,0 +1,90 @@
|
||||||
|
# Copyright (c) Open-MMLab. All rights reserved.
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import mmcv
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mmcls.core import visualization as vis
|
||||||
|
|
||||||
|
|
||||||
|
def test_color():
|
||||||
|
assert vis.color_val_matplotlib(mmcv.Color.blue) == (0., 0., 1.)
|
||||||
|
assert vis.color_val_matplotlib('green') == (0., 1., 0.)
|
||||||
|
assert vis.color_val_matplotlib((1, 2, 3)) == (3 / 255, 2 / 255, 1 / 255)
|
||||||
|
assert vis.color_val_matplotlib(100) == (100 / 255, 100 / 255, 100 / 255)
|
||||||
|
assert vis.color_val_matplotlib(np.zeros(3, dtype=int)) == (0., 0., 0.)
|
||||||
|
# forbid white color
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
vis.color_val_matplotlib([255, 255, 255])
|
||||||
|
# forbid float
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
vis.color_val_matplotlib(1.0)
|
||||||
|
# overflowed
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
vis.color_val_matplotlib((0, 0, 500))
|
||||||
|
|
||||||
|
|
||||||
|
def test_imshow_infos():
|
||||||
|
tmp_dir = osp.join(tempfile.gettempdir(), 'infos_image')
|
||||||
|
tmp_filename = osp.join(tmp_dir, 'image.jpg')
|
||||||
|
|
||||||
|
image = np.ones((10, 10, 3), np.uint8)
|
||||||
|
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
|
||||||
|
out_image = vis.imshow_infos(
|
||||||
|
image, result, out_file=tmp_filename, show=False)
|
||||||
|
assert osp.isfile(tmp_filename)
|
||||||
|
assert image.shape == out_image.shape
|
||||||
|
assert not np.allclose(image, out_image)
|
||||||
|
os.remove(tmp_filename)
|
||||||
|
|
||||||
|
# test grayscale images
|
||||||
|
image = np.ones((10, 10), np.uint8)
|
||||||
|
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
|
||||||
|
out_image = vis.imshow_infos(
|
||||||
|
image, result, out_file=tmp_filename, show=False)
|
||||||
|
assert osp.isfile(tmp_filename)
|
||||||
|
assert image.shape == out_image.shape[:2]
|
||||||
|
os.remove(tmp_filename)
|
||||||
|
|
||||||
|
# test show=True
|
||||||
|
image = np.ones((10, 10, 3), np.uint8)
|
||||||
|
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
|
||||||
|
|
||||||
|
def save_args(*args, **kwargs):
|
||||||
|
args_list = ['args']
|
||||||
|
args_list += [
|
||||||
|
str(arg) for arg in args if isinstance(arg, (str, bool, int))
|
||||||
|
]
|
||||||
|
args_list += [
|
||||||
|
f'{k}-{v}' for k, v in kwargs.items()
|
||||||
|
if isinstance(v, (str, bool, int))
|
||||||
|
]
|
||||||
|
out_path = osp.join(tmp_dir, '_'.join(args_list))
|
||||||
|
with open(out_path, 'w') as f:
|
||||||
|
f.write('test')
|
||||||
|
|
||||||
|
with patch('matplotlib.pyplot.show', save_args), \
|
||||||
|
patch('matplotlib.pyplot.pause', save_args):
|
||||||
|
vis.imshow_infos(image, result, show=True, wait_time=5)
|
||||||
|
assert osp.exists(osp.join(tmp_dir, 'args_block-False'))
|
||||||
|
assert osp.exists(osp.join(tmp_dir, 'args_5'))
|
||||||
|
|
||||||
|
vis.imshow_infos(image, result, show=True, wait_time=0)
|
||||||
|
assert osp.exists(osp.join(tmp_dir, 'args'))
|
||||||
|
|
||||||
|
# test adaptive dpi
|
||||||
|
def mock_fig_manager():
|
||||||
|
fig_manager = Mock()
|
||||||
|
fig_manager.window.winfo_screenheight = Mock(return_value=1440)
|
||||||
|
return fig_manager
|
||||||
|
|
||||||
|
with patch('matplotlib.pyplot.get_current_fig_manager',
|
||||||
|
mock_fig_manager), patch('matplotlib.pyplot.show'):
|
||||||
|
vis.imshow_infos(image, result, show=True)
|
||||||
|
|
||||||
|
shutil.rmtree(tmp_dir)
|
Loading…
Reference in New Issue