mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Visualizer] use FigureManager to manage figure to avoid affecting plt.show() outside Visualizer(#440)
* figure in Visualizer is not managed by plt * encapsulate code and remove unused code
This commit is contained in:
parent
b75962a660
commit
bb56cf42ab
@ -8,9 +8,12 @@ import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||
from matplotlib.collections import (LineCollection, PatchCollection,
|
||||
PolyCollection)
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.patches import Circle
|
||||
from matplotlib.pyplot import new_figure_manager
|
||||
|
||||
from mmengine.config import Config
|
||||
from mmengine.data import BaseDataElement
|
||||
@ -157,7 +160,7 @@ class Visualizer(ManagerMixin):
|
||||
vis_backends: Optional[List[Dict]] = None,
|
||||
save_dir: Optional[str] = None,
|
||||
fig_save_cfg=dict(frameon=False),
|
||||
fig_show_cfg=dict(frameon=False, num='show')
|
||||
fig_show_cfg=dict(frameon=False)
|
||||
) -> None:
|
||||
super().__init__(name)
|
||||
self._dataset_meta: Optional[dict] = None
|
||||
@ -196,17 +199,12 @@ class Visualizer(ManagerMixin):
|
||||
vis_backend.setdefault('save_dir', save_dir)
|
||||
self._vis_backends[name] = VISBACKENDS.build(vis_backend)
|
||||
|
||||
self.is_inline = 'inline' in plt.get_backend()
|
||||
|
||||
self.fig_save = None
|
||||
self.fig_show = None
|
||||
self.fig_save_num = fig_save_cfg.get('num', None)
|
||||
self.fig_show_num = fig_show_cfg.get('num', None)
|
||||
self.fig_save_cfg = fig_save_cfg
|
||||
self.fig_show_cfg = fig_show_cfg
|
||||
|
||||
(self.fig_save, self.ax_save,
|
||||
self.fig_save_num) = self._initialize_fig(fig_save_cfg)
|
||||
(self.fig_save_canvas, self.fig_save,
|
||||
self.ax_save) = self._initialize_fig(fig_save_cfg)
|
||||
self.dpi = self.fig_save.get_dpi()
|
||||
|
||||
if image is not None:
|
||||
@ -242,20 +240,22 @@ class Visualizer(ManagerMixin):
|
||||
continue_key (str): The key for users to continue. Defaults to
|
||||
the space key.
|
||||
"""
|
||||
if self.is_inline:
|
||||
return
|
||||
if self.fig_show is None or not plt.fignum_exists(self.fig_show_num):
|
||||
(self.fig_show, self.ax_show,
|
||||
self.fig_show_num) = self._initialize_fig(self.fig_show_cfg)
|
||||
is_inline = 'inline' in plt.get_backend()
|
||||
img = self.get_image() if drawn_img is None else drawn_img
|
||||
self.ax_show.cla()
|
||||
self.ax_show.axis(False)
|
||||
self.fig_show.canvas.manager.set_window_title(win_name) # type: ignore
|
||||
# Refresh canvas, necessary for Qt5 backend.
|
||||
self.ax_show.imshow(img)
|
||||
self.fig_show.canvas.draw() # type: ignore
|
||||
wait_continue(
|
||||
self.fig_show, timeout=wait_time, continue_key=continue_key)
|
||||
self._init_manager(win_name)
|
||||
fig = self.manager.canvas.figure
|
||||
# remove white edges by set subplot margin
|
||||
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
|
||||
fig.clear()
|
||||
ax = fig.add_subplot()
|
||||
ax.axis(False)
|
||||
ax.imshow(img)
|
||||
self.manager.canvas.draw()
|
||||
|
||||
# Find a better way for inline to show the image
|
||||
if is_inline:
|
||||
return fig
|
||||
wait_continue(fig, timeout=wait_time, continue_key=continue_key)
|
||||
|
||||
@master_only
|
||||
def set_image(self, image: np.ndarray) -> None:
|
||||
@ -291,7 +291,7 @@ class Visualizer(ManagerMixin):
|
||||
np.ndarray: the drawn image which channel is RGB.
|
||||
"""
|
||||
assert self._image is not None, 'Please set image using `set_image`'
|
||||
return img_from_canvas(self.fig_save.canvas) # type: ignore
|
||||
return img_from_canvas(self.fig_save_canvas) # type: ignore
|
||||
|
||||
def _initialize_fig(self, fig_cfg) -> tuple:
|
||||
"""Build figure according to fig_cfg.
|
||||
@ -300,15 +300,34 @@ class Visualizer(ManagerMixin):
|
||||
fig_cfg (dict): The config to build figure.
|
||||
|
||||
Returns:
|
||||
tuple: build figure, axes and fig number.
|
||||
tuple: build canvas figure and axes.
|
||||
"""
|
||||
fig = plt.figure(**fig_cfg)
|
||||
|
||||
fig = Figure(**fig_cfg)
|
||||
ax = fig.add_subplot()
|
||||
ax.axis(False)
|
||||
|
||||
# remove white edges by set subplot margin
|
||||
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
|
||||
return (fig, ax, fig.number)
|
||||
canvas = FigureCanvasAgg(fig)
|
||||
return canvas, fig, ax
|
||||
|
||||
def _init_manager(self, win_name: str) -> None:
|
||||
"""Initialize the matplot manager.
|
||||
|
||||
Args:
|
||||
win_name (str): The window name.
|
||||
"""
|
||||
if getattr(self, 'manager', None) is None:
|
||||
self.manager = new_figure_manager(
|
||||
num=1, FigureClass=Figure, **self.fig_show_cfg)
|
||||
|
||||
try:
|
||||
self.manager.set_window_title(win_name)
|
||||
except Exception:
|
||||
self.manager = new_figure_manager(
|
||||
num=1, FigureClass=Figure, **self.fig_show_cfg)
|
||||
self.manager.set_window_title(win_name)
|
||||
|
||||
@master_only
|
||||
def get_backend(self, name) -> 'BaseVisBackend':
|
||||
@ -982,7 +1001,9 @@ class Visualizer(ManagerMixin):
|
||||
axes.imshow(
|
||||
convert_overlay_heatmap(topk_featmap[i], overlaid_image,
|
||||
alpha))
|
||||
return img_from_canvas(fig.canvas)
|
||||
image = img_from_canvas(fig.canvas)
|
||||
plt.close(fig)
|
||||
return image
|
||||
|
||||
@master_only
|
||||
def add_config(self, config: Config, **kwargs):
|
||||
@ -1071,9 +1092,6 @@ class Visualizer(ManagerMixin):
|
||||
|
||||
def close(self) -> None:
|
||||
"""close an opened object."""
|
||||
plt.close(self.fig_save)
|
||||
if self.fig_show is not None:
|
||||
plt.close(self.fig_show)
|
||||
for vis_backend in self._vis_backends.values():
|
||||
vis_backend.close()
|
||||
|
||||
|
@ -4,7 +4,6 @@ import time
|
||||
from typing import Any
|
||||
from unittest import TestCase
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
@ -171,12 +170,10 @@ class TestVisualizer(TestCase):
|
||||
image=self.image,
|
||||
vis_backends=copy.deepcopy(self.vis_backend_cfg),
|
||||
save_dir='temp_dir')
|
||||
fig_num = visualizer.fig_save_num
|
||||
assert fig_num in plt.get_fignums()
|
||||
|
||||
for name in ['mock1', 'mock2']:
|
||||
assert visualizer.get_backend(name)._close is False
|
||||
visualizer.close()
|
||||
assert fig_num not in plt.get_fignums()
|
||||
for name in ['mock1', 'mock2']:
|
||||
assert visualizer.get_backend(name)._close is True
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user