[Refactor] Use new API of matplotlib to handle blocking input in visualization. (#568)
* [Refactor] Use new API of matplotlib to handle blocking input in visualization. * Modify unit testspull/571/head^2
parent
33f049b4c2
commit
78d6d8503f
|
@ -1,11 +1,7 @@
|
|||
from threading import Timer
|
||||
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from matplotlib.backend_bases import CloseEvent
|
||||
from matplotlib.blocking_input import BlockingInput
|
||||
|
||||
# A small value
|
||||
EPS = 1e-2
|
||||
|
@ -41,7 +37,7 @@ class BaseFigureContextManager:
|
|||
"""
|
||||
|
||||
def __init__(self, axis=False, fig_save_cfg={}, fig_show_cfg={}) -> None:
|
||||
self.is_inline = 'inline' in matplotlib.get_backend()
|
||||
self.is_inline = 'inline' in plt.get_backend()
|
||||
|
||||
# Because save and show need different figure size
|
||||
# We set two figure and axes to handle save and show
|
||||
|
@ -52,7 +48,6 @@ class BaseFigureContextManager:
|
|||
self.fig_show: plt.Figure = None
|
||||
self.fig_show_cfg = fig_show_cfg
|
||||
self.ax_show: plt.Axes = None
|
||||
self.blocking_input: BlockingInput = None
|
||||
|
||||
self.axis = axis
|
||||
|
||||
|
@ -83,8 +78,6 @@ class BaseFigureContextManager:
|
|||
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
|
||||
|
||||
self.fig_show, self.ax_show = fig, ax
|
||||
self.blocking_input = BlockingInput(
|
||||
self.fig_show, eventslist=('key_press_event', 'close_event'))
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
if self.is_inline:
|
||||
|
@ -95,14 +88,6 @@ class BaseFigureContextManager:
|
|||
plt.close(self.fig_save)
|
||||
plt.close(self.fig_show)
|
||||
|
||||
try:
|
||||
# In matplotlib>=3.4.0, with TkAgg, plt.close will destroy
|
||||
# window after idle, need to update manually.
|
||||
# Refers to https://github.com/matplotlib/matplotlib/blob/v3.4.x/lib/matplotlib/backends/_backend_tk.py#L470 # noqa: E501
|
||||
self.fig_show.canvas.manager.window.update()
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def prepare(self):
|
||||
if self.is_inline:
|
||||
# if use inline backend, just rebuild the fig_save.
|
||||
|
@ -121,29 +106,59 @@ class BaseFigureContextManager:
|
|||
self.ax_show.cla()
|
||||
self.ax_show.axis(self.axis)
|
||||
|
||||
def wait_continue(self, timeout=0):
|
||||
def wait_continue(self, timeout=0, continue_key=' ') -> int:
|
||||
"""Show the image and wait for the user's input.
|
||||
|
||||
This implementation refers to
|
||||
https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py
|
||||
|
||||
Args:
|
||||
timeout (int): If positive, continue after ``timeout`` seconds.
|
||||
Defaults to 0.
|
||||
continue_key (str): The key for users to continue. Defaults to
|
||||
the space key.
|
||||
|
||||
Returns:
|
||||
int: If zero, means time out or the user pressed ``continue_key``,
|
||||
and if one, means the user closed the show figure.
|
||||
""" # noqa: E501
|
||||
if self.is_inline:
|
||||
# If use inline backend, interactive input and timeout is no use.
|
||||
return
|
||||
|
||||
# In matplotlib==3.4.x, with TkAgg, official timeout api of
|
||||
# start_event_loop cannot work properly. Use a Timer to directly stop
|
||||
# event loop.
|
||||
if timeout > 0:
|
||||
timer = Timer(timeout, self.fig_show.canvas.stop_event_loop)
|
||||
timer.start()
|
||||
while True:
|
||||
# Disable matplotlib default hotkey to close figure.
|
||||
with plt.rc_context({'keymap.quit': []}):
|
||||
key_press = self.blocking_input(n=1, timeout=0)
|
||||
if self.fig_show.canvas.manager:
|
||||
# Ensure that the figure is shown
|
||||
self.fig_show.show()
|
||||
|
||||
# Timeout or figure is closed or press space or press 'q'
|
||||
if len(key_press) == 0 or isinstance(
|
||||
key_press[0],
|
||||
CloseEvent) or key_press[0].key in ['q', ' ']:
|
||||
break
|
||||
if timeout > 0:
|
||||
timer.cancel()
|
||||
while True:
|
||||
|
||||
# Connect the events to the handler function call.
|
||||
event = None
|
||||
|
||||
def handler(ev):
|
||||
# Set external event variable
|
||||
nonlocal event
|
||||
# Qt backend may fire two events at the same time,
|
||||
# use a condition to avoid missing close event.
|
||||
event = ev if not isinstance(event, CloseEvent) else event
|
||||
self.fig_show.canvas.stop_event_loop()
|
||||
|
||||
cids = [
|
||||
self.fig_show.canvas.mpl_connect(name, handler)
|
||||
for name in ('key_press_event', 'close_event')
|
||||
]
|
||||
|
||||
try:
|
||||
self.fig_show.canvas.start_event_loop(timeout)
|
||||
finally: # Run even on exception like ctrl-c.
|
||||
# Disconnect the callbacks.
|
||||
for cid in cids:
|
||||
self.fig_show.canvas.mpl_disconnect(cid)
|
||||
|
||||
if isinstance(event, CloseEvent):
|
||||
return 1 # Quit for close.
|
||||
elif event is None or event.key == continue_key:
|
||||
return 0 # Quit for continue.
|
||||
|
||||
|
||||
class ImshowInfosContextManager(BaseFigureContextManager):
|
||||
|
@ -259,6 +274,7 @@ class ImshowInfosContextManager(BaseFigureContextManager):
|
|||
if out_file is not None:
|
||||
mmcv.imwrite(img_save, out_file)
|
||||
|
||||
ret = 0
|
||||
if show and not self.is_inline:
|
||||
# Reserve some space for the tip.
|
||||
self.ax_show.set_title(win_name)
|
||||
|
@ -274,13 +290,13 @@ class ImshowInfosContextManager(BaseFigureContextManager):
|
|||
# Refresh canvas, necessary for Qt5 backend.
|
||||
self.fig_show.canvas.draw()
|
||||
|
||||
self.wait_continue(timeout=wait_time)
|
||||
ret = self.wait_continue(timeout=wait_time)
|
||||
elif (not show) and self.is_inline:
|
||||
# If use inline backend, we use fig_save to show the image
|
||||
# So we need to close it if users don't want to show.
|
||||
plt.close(self.fig_save)
|
||||
|
||||
return img_save
|
||||
return ret, img_save
|
||||
|
||||
|
||||
def imshow_infos(img,
|
||||
|
@ -313,7 +329,7 @@ def imshow_infos(img,
|
|||
np.ndarray: The image with extra infomations.
|
||||
"""
|
||||
with ImshowInfosContextManager(fig_size=fig_size) as manager:
|
||||
img = manager.put_img_infos(
|
||||
_, img = manager.put_img_infos(
|
||||
img,
|
||||
infos,
|
||||
text_color=text_color,
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
# Copyright (c) Open-MMLab. All rights reserved.
|
||||
import os
|
||||
import os.path as osp
|
||||
import shutil
|
||||
import tempfile
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import mmcv
|
||||
|
@ -52,30 +51,8 @@ def test_imshow_infos():
|
|||
assert image.shape == out_image.shape[:2]
|
||||
os.remove(tmp_filename)
|
||||
|
||||
# test show=True
|
||||
image = np.ones((10, 10, 3), np.uint8)
|
||||
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
|
||||
|
||||
def mock_blocking_input(self, n=1, timeout=30):
|
||||
keypress = Mock()
|
||||
keypress.key = ' '
|
||||
out_path = osp.join(tmp_dir, '_'.join([str(n), str(timeout)]))
|
||||
with open(out_path, 'w') as f:
|
||||
f.write('test')
|
||||
return [keypress]
|
||||
|
||||
with patch('matplotlib.blocking_input.BlockingInput.__call__',
|
||||
mock_blocking_input):
|
||||
vis.imshow_infos(image, result, show=True, wait_time=5)
|
||||
assert osp.exists(osp.join(tmp_dir, '1_0'))
|
||||
|
||||
shutil.rmtree(tmp_dir)
|
||||
|
||||
|
||||
@patch(
|
||||
'matplotlib.blocking_input.BlockingInput.__call__',
|
||||
return_value=[Mock(key=' ')])
|
||||
def test_context_manager(mock_blocking_input):
|
||||
def test_figure_context_manager():
|
||||
# test show multiple images with the same figure.
|
||||
images = [
|
||||
np.random.randint(0, 255, (100, 100, 3), np.uint8) for _ in range(5)
|
||||
|
@ -85,22 +62,39 @@ def test_context_manager(mock_blocking_input):
|
|||
with vis.ImshowInfosContextManager() as manager:
|
||||
fig_show = manager.fig_show
|
||||
fig_save = manager.fig_save
|
||||
|
||||
# Test time out
|
||||
fig_show.canvas.start_event_loop = MagicMock()
|
||||
fig_show.canvas.end_event_loop = MagicMock()
|
||||
for image in images:
|
||||
out_image = manager.put_img_infos(image, result, show=True)
|
||||
ret, out_image = manager.put_img_infos(image, result, show=True)
|
||||
assert ret == 0
|
||||
assert image.shape == out_image.shape
|
||||
assert not np.allclose(image, out_image)
|
||||
assert fig_show is manager.fig_show
|
||||
assert fig_save is manager.fig_save
|
||||
|
||||
# test rebuild figure if user destroyed it.
|
||||
with vis.ImshowInfosContextManager() as manager:
|
||||
fig_save = manager.fig_save
|
||||
# Test continue key
|
||||
fig_show.canvas.start_event_loop = (
|
||||
lambda _: fig_show.canvas.key_press_event(' '))
|
||||
for image in images:
|
||||
fig_show = manager.fig_show
|
||||
plt.close(manager.fig_show)
|
||||
|
||||
out_image = manager.put_img_infos(image, result, show=True)
|
||||
ret, out_image = manager.put_img_infos(image, result, show=True)
|
||||
assert ret == 0
|
||||
assert image.shape == out_image.shape
|
||||
assert not np.allclose(image, out_image)
|
||||
assert not (fig_show is manager.fig_show)
|
||||
assert fig_show is manager.fig_show
|
||||
assert fig_save is manager.fig_save
|
||||
|
||||
# Test close figure manually
|
||||
fig_show = manager.fig_show
|
||||
|
||||
def destroy(*_, **__):
|
||||
fig_show.canvas.close_event()
|
||||
plt.close(fig_show)
|
||||
|
||||
fig_show.canvas.start_event_loop = destroy
|
||||
ret, out_image = manager.put_img_infos(images[0], result, show=True)
|
||||
assert ret == 1
|
||||
assert image.shape == out_image.shape
|
||||
assert not np.allclose(image, out_image)
|
||||
assert fig_save is manager.fig_save
|
||||
|
|
|
@ -238,7 +238,7 @@ def main():
|
|||
|
||||
infos = dict(label=CLASSES[item['gt_label']])
|
||||
|
||||
manager.put_img_infos(
|
||||
ret, _ = manager.put_img_infos(
|
||||
image,
|
||||
infos,
|
||||
font_size=20,
|
||||
|
@ -248,6 +248,10 @@ def main():
|
|||
|
||||
progressBar.update()
|
||||
|
||||
if ret == 1:
|
||||
print('\nMannualy interrupted.')
|
||||
break
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
Loading…
Reference in New Issue