[Refactor] Use new API of matplotlib to handle blocking input in visualization. (#568)

* [Refactor] Use new API of matplotlib to handle blocking input in
visualization.

* Modify unit tests
pull/571/head^2
Ma Zerun 2021-12-02 17:46:40 +08:00 committed by GitHub
parent 33f049b4c2
commit 78d6d8503f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 86 additions and 72 deletions

View File

@ -1,11 +1,7 @@
from threading import Timer
import matplotlib
import matplotlib.pyplot as plt
import mmcv
import numpy as np
from matplotlib.backend_bases import CloseEvent
from matplotlib.blocking_input import BlockingInput
# A small value
EPS = 1e-2
@ -41,7 +37,7 @@ class BaseFigureContextManager:
"""
def __init__(self, axis=False, fig_save_cfg={}, fig_show_cfg={}) -> None:
self.is_inline = 'inline' in matplotlib.get_backend()
self.is_inline = 'inline' in plt.get_backend()
# Because save and show need different figure size
# We set two figure and axes to handle save and show
@ -52,7 +48,6 @@ class BaseFigureContextManager:
self.fig_show: plt.Figure = None
self.fig_show_cfg = fig_show_cfg
self.ax_show: plt.Axes = None
self.blocking_input: BlockingInput = None
self.axis = axis
@ -83,8 +78,6 @@ class BaseFigureContextManager:
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
self.fig_show, self.ax_show = fig, ax
self.blocking_input = BlockingInput(
self.fig_show, eventslist=('key_press_event', 'close_event'))
def __exit__(self, exc_type, exc_value, traceback):
if self.is_inline:
@ -95,14 +88,6 @@ class BaseFigureContextManager:
plt.close(self.fig_save)
plt.close(self.fig_show)
try:
# In matplotlib>=3.4.0, with TkAgg, plt.close will destroy
# window after idle, need to update manually.
# Refers to https://github.com/matplotlib/matplotlib/blob/v3.4.x/lib/matplotlib/backends/_backend_tk.py#L470 # noqa: E501
self.fig_show.canvas.manager.window.update()
except AttributeError:
pass
def prepare(self):
if self.is_inline:
# if use inline backend, just rebuild the fig_save.
@ -121,29 +106,59 @@ class BaseFigureContextManager:
self.ax_show.cla()
self.ax_show.axis(self.axis)
def wait_continue(self, timeout=0):
def wait_continue(self, timeout=0, continue_key=' ') -> int:
"""Show the image and wait for the user's input.
This implementation refers to
https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py
Args:
timeout (int): If positive, continue after ``timeout`` seconds.
Defaults to 0.
continue_key (str): The key for users to continue. Defaults to
the space key.
Returns:
int: If zero, means time out or the user pressed ``continue_key``,
and if one, means the user closed the show figure.
""" # noqa: E501
if self.is_inline:
# If use inline backend, interactive input and timeout is no use.
return
# In matplotlib==3.4.x, with TkAgg, official timeout api of
# start_event_loop cannot work properly. Use a Timer to directly stop
# event loop.
if timeout > 0:
timer = Timer(timeout, self.fig_show.canvas.stop_event_loop)
timer.start()
while True:
# Disable matplotlib default hotkey to close figure.
with plt.rc_context({'keymap.quit': []}):
key_press = self.blocking_input(n=1, timeout=0)
if self.fig_show.canvas.manager:
# Ensure that the figure is shown
self.fig_show.show()
# Timeout or figure is closed or press space or press 'q'
if len(key_press) == 0 or isinstance(
key_press[0],
CloseEvent) or key_press[0].key in ['q', ' ']:
break
if timeout > 0:
timer.cancel()
while True:
# Connect the events to the handler function call.
event = None
def handler(ev):
# Set external event variable
nonlocal event
# Qt backend may fire two events at the same time,
# use a condition to avoid missing close event.
event = ev if not isinstance(event, CloseEvent) else event
self.fig_show.canvas.stop_event_loop()
cids = [
self.fig_show.canvas.mpl_connect(name, handler)
for name in ('key_press_event', 'close_event')
]
try:
self.fig_show.canvas.start_event_loop(timeout)
finally: # Run even on exception like ctrl-c.
# Disconnect the callbacks.
for cid in cids:
self.fig_show.canvas.mpl_disconnect(cid)
if isinstance(event, CloseEvent):
return 1 # Quit for close.
elif event is None or event.key == continue_key:
return 0 # Quit for continue.
class ImshowInfosContextManager(BaseFigureContextManager):
@ -259,6 +274,7 @@ class ImshowInfosContextManager(BaseFigureContextManager):
if out_file is not None:
mmcv.imwrite(img_save, out_file)
ret = 0
if show and not self.is_inline:
# Reserve some space for the tip.
self.ax_show.set_title(win_name)
@ -274,13 +290,13 @@ class ImshowInfosContextManager(BaseFigureContextManager):
# Refresh canvas, necessary for Qt5 backend.
self.fig_show.canvas.draw()
self.wait_continue(timeout=wait_time)
ret = self.wait_continue(timeout=wait_time)
elif (not show) and self.is_inline:
# If use inline backend, we use fig_save to show the image
# So we need to close it if users don't want to show.
plt.close(self.fig_save)
return img_save
return ret, img_save
def imshow_infos(img,
@ -313,7 +329,7 @@ def imshow_infos(img,
np.ndarray: The image with extra infomations.
"""
with ImshowInfosContextManager(fig_size=fig_size) as manager:
img = manager.put_img_infos(
_, img = manager.put_img_infos(
img,
infos,
text_color=text_color,

View File

@ -1,9 +1,8 @@
# Copyright (c) Open-MMLab. All rights reserved.
import os
import os.path as osp
import shutil
import tempfile
from unittest.mock import Mock, patch
from unittest.mock import MagicMock
import matplotlib.pyplot as plt
import mmcv
@ -52,30 +51,8 @@ def test_imshow_infos():
assert image.shape == out_image.shape[:2]
os.remove(tmp_filename)
# test show=True
image = np.ones((10, 10, 3), np.uint8)
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
def mock_blocking_input(self, n=1, timeout=30):
keypress = Mock()
keypress.key = ' '
out_path = osp.join(tmp_dir, '_'.join([str(n), str(timeout)]))
with open(out_path, 'w') as f:
f.write('test')
return [keypress]
with patch('matplotlib.blocking_input.BlockingInput.__call__',
mock_blocking_input):
vis.imshow_infos(image, result, show=True, wait_time=5)
assert osp.exists(osp.join(tmp_dir, '1_0'))
shutil.rmtree(tmp_dir)
@patch(
'matplotlib.blocking_input.BlockingInput.__call__',
return_value=[Mock(key=' ')])
def test_context_manager(mock_blocking_input):
def test_figure_context_manager():
# test show multiple images with the same figure.
images = [
np.random.randint(0, 255, (100, 100, 3), np.uint8) for _ in range(5)
@ -85,22 +62,39 @@ def test_context_manager(mock_blocking_input):
with vis.ImshowInfosContextManager() as manager:
fig_show = manager.fig_show
fig_save = manager.fig_save
# Test time out
fig_show.canvas.start_event_loop = MagicMock()
fig_show.canvas.end_event_loop = MagicMock()
for image in images:
out_image = manager.put_img_infos(image, result, show=True)
ret, out_image = manager.put_img_infos(image, result, show=True)
assert ret == 0
assert image.shape == out_image.shape
assert not np.allclose(image, out_image)
assert fig_show is manager.fig_show
assert fig_save is manager.fig_save
# test rebuild figure if user destroyed it.
with vis.ImshowInfosContextManager() as manager:
fig_save = manager.fig_save
# Test continue key
fig_show.canvas.start_event_loop = (
lambda _: fig_show.canvas.key_press_event(' '))
for image in images:
fig_show = manager.fig_show
plt.close(manager.fig_show)
out_image = manager.put_img_infos(image, result, show=True)
ret, out_image = manager.put_img_infos(image, result, show=True)
assert ret == 0
assert image.shape == out_image.shape
assert not np.allclose(image, out_image)
assert not (fig_show is manager.fig_show)
assert fig_show is manager.fig_show
assert fig_save is manager.fig_save
# Test close figure manually
fig_show = manager.fig_show
def destroy(*_, **__):
fig_show.canvas.close_event()
plt.close(fig_show)
fig_show.canvas.start_event_loop = destroy
ret, out_image = manager.put_img_infos(images[0], result, show=True)
assert ret == 1
assert image.shape == out_image.shape
assert not np.allclose(image, out_image)
assert fig_save is manager.fig_save

View File

@ -238,7 +238,7 @@ def main():
infos = dict(label=CLASSES[item['gt_label']])
manager.put_img_infos(
ret, _ = manager.put_img_infos(
image,
infos,
font_size=20,
@ -248,6 +248,10 @@ def main():
progressBar.update()
if ret == 1:
print('\nMannualy interrupted.')
break
if __name__ == '__main__':
main()