[Feature] Add `ClsVisualizer`.
parent
27e685fe10
commit
0537c4d70c
|
@ -4,3 +4,4 @@ from .evaluation import * # noqa: F401, F403
|
|||
from .hook import * # noqa: F401, F403
|
||||
from .optimizers import * # noqa: F401, F403
|
||||
from .utils import * # noqa: F401, F403
|
||||
from .visualization import * # noqa: F401, F403
|
||||
|
|
|
@ -10,8 +10,8 @@ import mmcv
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.runner import EpochBasedRunner, get_dist_info
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.logging import print_log
|
||||
from torch.functional import Tensor
|
||||
from torch.nn import GroupNorm
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
|
|
@ -1,8 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .image import (BaseFigureContextManager, ImshowInfosContextManager,
|
||||
color_val_matplotlib, imshow_infos)
|
||||
from .cls_visualizer import ClsVisualizer
|
||||
|
||||
__all__ = [
|
||||
'BaseFigureContextManager', 'ImshowInfosContextManager', 'imshow_infos',
|
||||
'color_val_matplotlib'
|
||||
]
|
||||
__all__ = ['ClsVisualizer']
|
||||
|
|
|
@ -0,0 +1,153 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmengine import Visualizer
|
||||
from mmengine.dist import master_only
|
||||
|
||||
from mmcls.core import ClsDataSample
|
||||
from mmcls.registry import VISUALIZERS
|
||||
|
||||
|
||||
@VISUALIZERS.register_module()
|
||||
class ClsVisualizer(Visualizer):
|
||||
"""Universal Visualizer for classification task.
|
||||
|
||||
Args:
|
||||
name (str): Name of the instance. Defaults to 'visualizer'.
|
||||
image (np.ndarray, optional): the origin image to draw. The format
|
||||
should be RGB. Defaults to None.
|
||||
vis_backends (list, optional): Visual backend config list.
|
||||
Default to None.
|
||||
save_dir (str, optional): Save file dir for all storage backends.
|
||||
If it is None, the backend storage will not save any data.
|
||||
fig_save_cfg (dict): Keyword parameters of figure for saving.
|
||||
Defaults to empty dict.
|
||||
fig_show_cfg (dict): Keyword parameters of figure for showing.
|
||||
Defaults to empty dict.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> import mmcv
|
||||
>>> from pathlib import Path
|
||||
>>> from mmcls.core import ClsDataSample, ClsVisualizer
|
||||
>>> # Example image
|
||||
>>> img = mmcv.imread("./demo/bird.JPEG", channel_order='rgb')
|
||||
>>> # Example annotation
|
||||
>>> data_sample = ClsDataSample().set_gt_label(1).set_pred_label(1).\
|
||||
... set_pred_score(torch.tensor([0.1, 0.8, 0.1]))
|
||||
>>> # Setup the visualizer
|
||||
>>> vis = ClsVisualizer(
|
||||
... save_dir="./outputs",
|
||||
... vis_backends=[dict(type='LocalVisBackend')])
|
||||
>>> # Set classes names
|
||||
>>> vis.dataset_meta = {'CLASSES': ['cat', 'bird', 'dog']}
|
||||
>>> # Show the example image with annotation in a figure.
|
||||
>>> # And it will ignore all preset storage backends.
|
||||
>>> vis.add_datasample('res', img, data_sample, show=True)
|
||||
>>> # Save the visualization result by the specified storage backends.
|
||||
>>> vis.add_datasample('res', img, data_sample)
|
||||
>>> assert Path('./outputs/vis_data/vis_image/res_0.png').exists()
|
||||
>>> # Save another visualization result with the same name.
|
||||
>>> vis.add_datasample('res', img, data_sample, step=1)
|
||||
>>> assert Path('./outputs/vis_data/vis_image/res_1.png').exists()
|
||||
"""
|
||||
|
||||
@master_only
|
||||
def add_datasample(self,
|
||||
name: str,
|
||||
image: np.ndarray,
|
||||
data_sample: Optional[ClsDataSample] = None,
|
||||
draw_gt: bool = True,
|
||||
draw_pred: bool = True,
|
||||
draw_score: bool = True,
|
||||
show: bool = False,
|
||||
text_cfg: dict = dict(),
|
||||
wait_time: float = 0,
|
||||
out_file: Optional[str] = None,
|
||||
step: int = 0) -> None:
|
||||
"""Draw datasample and save to all backends.
|
||||
|
||||
- If ``show`` is True, all storage backends are ignored and then
|
||||
displayed in a local window.
|
||||
- If the ``out_file`` parameter is specified, the drawn image
|
||||
will be additionally saved to ``out_file``. It is usually used
|
||||
in script mode like ``image_demo.py``
|
||||
|
||||
Args:
|
||||
name (str): The image identifier.
|
||||
image (np.ndarray): The image to draw.
|
||||
data_sample (:obj:`ClsDataSample`, optional): The annotation of the
|
||||
image. Default to None.
|
||||
draw_gt (bool): Whether to draw ground truth labels.
|
||||
Default to True.
|
||||
draw_pred (bool): Whether to draw prediction labels.
|
||||
Default to True.
|
||||
draw_score (bool): Whether to draw the prediction scores
|
||||
of prediction categories. Default to True.
|
||||
show (bool): Whether to display the drawn image. Default to False.
|
||||
text_cfg (dict): Extra text setting, which accepts
|
||||
arguments of :attr:`mmengine.Visualizer.draw_texts`.
|
||||
Defaults to an empty dict.
|
||||
wait_time (float): The interval of show (s). Default to 0, which
|
||||
means "forever".
|
||||
out_file (str, optional): Extra path to save the visualization
|
||||
result. Whether specified or not, the visualizer will still
|
||||
save the results by its storage backends. Default to None.
|
||||
step (int): Global step value to record. Default to 0.
|
||||
"""
|
||||
classes = None
|
||||
if self.dataset_meta is not None:
|
||||
classes = self.dataset_meta.get('CLASSES', None)
|
||||
|
||||
texts = []
|
||||
self.set_image(image)
|
||||
|
||||
if draw_gt and 'gt_label' in data_sample:
|
||||
gt_label = data_sample.gt_label
|
||||
idx = gt_label.label.tolist()
|
||||
class_labels = [''] * len(idx)
|
||||
if classes is not None:
|
||||
class_labels = [f' ({classes[i]})' for i in idx]
|
||||
labels = [str(idx[i]) + class_labels[i] for i in range(len(idx))]
|
||||
prefix = 'Ground truth: '
|
||||
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
|
||||
|
||||
if draw_pred and 'pred_label' in data_sample:
|
||||
pred_label = data_sample.pred_label
|
||||
idx = pred_label.label.tolist()
|
||||
score_labels = [''] * len(idx)
|
||||
class_labels = [''] * len(idx)
|
||||
if draw_score and 'score' in pred_label:
|
||||
score_labels = [
|
||||
f', {pred_label.score[i].item():.2f}' for i in idx
|
||||
]
|
||||
|
||||
if classes is not None:
|
||||
class_labels = [f' ({classes[i]})' for i in idx]
|
||||
|
||||
labels = [
|
||||
str(idx[i]) + score_labels[i] + class_labels[i]
|
||||
for i in range(len(idx))
|
||||
]
|
||||
prefix = 'Prediction: '
|
||||
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
|
||||
|
||||
text_cfg = {
|
||||
'positions': np.array([(5, 5)]),
|
||||
'font_families': 'monospace',
|
||||
'colors': 'white',
|
||||
'bboxes': dict(facecolor='black', alpha=0.5, boxstyle='Round'),
|
||||
**text_cfg
|
||||
}
|
||||
self.draw_texts('\n'.join(texts), **text_cfg)
|
||||
drawn_img = self.get_image()
|
||||
|
||||
if show:
|
||||
self.show(drawn_img, win_name=name, wait_time=wait_time)
|
||||
else:
|
||||
self.add_image(name, drawn_img, step=step)
|
||||
|
||||
if out_file is not None:
|
||||
mmcv.imwrite(drawn_img[..., ::-1], out_file)
|
|
@ -1,343 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import matplotlib.pyplot as plt
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from matplotlib.backend_bases import CloseEvent
|
||||
|
||||
# A small value
|
||||
EPS = 1e-2
|
||||
|
||||
|
||||
def color_val_matplotlib(color):
|
||||
"""Convert various input in BGR order to normalized RGB matplotlib color
|
||||
tuples,
|
||||
|
||||
Args:
|
||||
color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Color inputs
|
||||
|
||||
Returns:
|
||||
tuple[float]: A tuple of 3 normalized floats indicating RGB channels.
|
||||
"""
|
||||
color = mmcv.color_val(color)
|
||||
color = [color / 255 for color in color[::-1]]
|
||||
return tuple(color)
|
||||
|
||||
|
||||
class BaseFigureContextManager:
|
||||
"""Context Manager to reuse matplotlib figure.
|
||||
|
||||
It provides a figure for saving and a figure for showing to support
|
||||
different settings.
|
||||
|
||||
Args:
|
||||
axis (bool): Whether to show the axis lines.
|
||||
fig_save_cfg (dict): Keyword parameters of figure for saving.
|
||||
Defaults to empty dict.
|
||||
fig_show_cfg (dict): Keyword parameters of figure for showing.
|
||||
Defaults to empty dict.
|
||||
"""
|
||||
|
||||
def __init__(self, axis=False, fig_save_cfg={}, fig_show_cfg={}) -> None:
|
||||
self.is_inline = 'inline' in plt.get_backend()
|
||||
|
||||
# Because save and show need different figure size
|
||||
# We set two figure and axes to handle save and show
|
||||
self.fig_save: plt.Figure = None
|
||||
self.fig_save_cfg = fig_save_cfg
|
||||
self.ax_save: plt.Axes = None
|
||||
|
||||
self.fig_show: plt.Figure = None
|
||||
self.fig_show_cfg = fig_show_cfg
|
||||
self.ax_show: plt.Axes = None
|
||||
|
||||
self.axis = axis
|
||||
|
||||
def __enter__(self):
|
||||
if not self.is_inline:
|
||||
# If use inline backend, we cannot control which figure to show,
|
||||
# so disable the interactive fig_show, and put the initialization
|
||||
# of fig_save to `prepare` function.
|
||||
self._initialize_fig_save()
|
||||
self._initialize_fig_show()
|
||||
return self
|
||||
|
||||
def _initialize_fig_save(self):
|
||||
fig = plt.figure(**self.fig_save_cfg)
|
||||
ax = fig.add_subplot()
|
||||
|
||||
# remove white edges by set subplot margin
|
||||
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
|
||||
|
||||
self.fig_save, self.ax_save = fig, ax
|
||||
|
||||
def _initialize_fig_show(self):
|
||||
# fig_save will be resized to image size, only fig_show needs fig_size.
|
||||
fig = plt.figure(**self.fig_show_cfg)
|
||||
ax = fig.add_subplot()
|
||||
|
||||
# remove white edges by set subplot margin
|
||||
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
|
||||
|
||||
self.fig_show, self.ax_show = fig, ax
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
if self.is_inline:
|
||||
# If use inline backend, whether to close figure depends on if
|
||||
# users want to show the image.
|
||||
return
|
||||
|
||||
plt.close(self.fig_save)
|
||||
plt.close(self.fig_show)
|
||||
|
||||
def prepare(self):
|
||||
if self.is_inline:
|
||||
# if use inline backend, just rebuild the fig_save.
|
||||
self._initialize_fig_save()
|
||||
self.ax_save.cla()
|
||||
self.ax_save.axis(self.axis)
|
||||
return
|
||||
|
||||
# If users force to destroy the window, rebuild fig_show.
|
||||
if not plt.fignum_exists(self.fig_show.number):
|
||||
self._initialize_fig_show()
|
||||
|
||||
# Clear all axes
|
||||
self.ax_save.cla()
|
||||
self.ax_save.axis(self.axis)
|
||||
self.ax_show.cla()
|
||||
self.ax_show.axis(self.axis)
|
||||
|
||||
def wait_continue(self, timeout=0, continue_key=' ') -> int:
|
||||
"""Show the image and wait for the user's input.
|
||||
|
||||
This implementation refers to
|
||||
https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py
|
||||
|
||||
Args:
|
||||
timeout (int): If positive, continue after ``timeout`` seconds.
|
||||
Defaults to 0.
|
||||
continue_key (str): The key for users to continue. Defaults to
|
||||
the space key.
|
||||
|
||||
Returns:
|
||||
int: If zero, means time out or the user pressed ``continue_key``,
|
||||
and if one, means the user closed the show figure.
|
||||
""" # noqa: E501
|
||||
if self.is_inline:
|
||||
# If use inline backend, interactive input and timeout is no use.
|
||||
return
|
||||
|
||||
if self.fig_show.canvas.manager:
|
||||
# Ensure that the figure is shown
|
||||
self.fig_show.show()
|
||||
|
||||
while True:
|
||||
|
||||
# Connect the events to the handler function call.
|
||||
event = None
|
||||
|
||||
def handler(ev):
|
||||
# Set external event variable
|
||||
nonlocal event
|
||||
# Qt backend may fire two events at the same time,
|
||||
# use a condition to avoid missing close event.
|
||||
event = ev if not isinstance(event, CloseEvent) else event
|
||||
self.fig_show.canvas.stop_event_loop()
|
||||
|
||||
cids = [
|
||||
self.fig_show.canvas.mpl_connect(name, handler)
|
||||
for name in ('key_press_event', 'close_event')
|
||||
]
|
||||
|
||||
try:
|
||||
self.fig_show.canvas.start_event_loop(timeout)
|
||||
finally: # Run even on exception like ctrl-c.
|
||||
# Disconnect the callbacks.
|
||||
for cid in cids:
|
||||
self.fig_show.canvas.mpl_disconnect(cid)
|
||||
|
||||
if isinstance(event, CloseEvent):
|
||||
return 1 # Quit for close.
|
||||
elif event is None or event.key == continue_key:
|
||||
return 0 # Quit for continue.
|
||||
|
||||
|
||||
class ImshowInfosContextManager(BaseFigureContextManager):
|
||||
"""Context Manager to reuse matplotlib figure and put infos on images.
|
||||
|
||||
Args:
|
||||
fig_size (tuple[int]): Size of the figure to show image.
|
||||
|
||||
Examples:
|
||||
>>> import mmcv
|
||||
>>> from mmcls.core import visualization as vis
|
||||
>>> img1 = mmcv.imread("./1.png")
|
||||
>>> info1 = {'class': 'cat', 'label': 0}
|
||||
>>> img2 = mmcv.imread("./2.png")
|
||||
>>> info2 = {'class': 'dog', 'label': 1}
|
||||
>>> with vis.ImshowInfosContextManager() as manager:
|
||||
... # Show img1
|
||||
... manager.put_img_infos(img1, info1)
|
||||
... # Show img2 on the same figure and save output image.
|
||||
... manager.put_img_infos(
|
||||
... img2, info2, out_file='./2_out.png')
|
||||
"""
|
||||
|
||||
def __init__(self, fig_size=(15, 10)):
|
||||
super().__init__(
|
||||
axis=False,
|
||||
# A proper dpi for image save with default font size.
|
||||
fig_save_cfg=dict(frameon=False, dpi=36),
|
||||
fig_show_cfg=dict(frameon=False, figsize=fig_size))
|
||||
|
||||
def _put_text(self, ax, text, x, y, text_color, font_size):
|
||||
ax.text(
|
||||
x,
|
||||
y,
|
||||
f'{text}',
|
||||
bbox={
|
||||
'facecolor': 'black',
|
||||
'alpha': 0.7,
|
||||
'pad': 0.2,
|
||||
'edgecolor': 'none',
|
||||
'boxstyle': 'round'
|
||||
},
|
||||
color=text_color,
|
||||
fontsize=font_size,
|
||||
family='monospace',
|
||||
verticalalignment='top',
|
||||
horizontalalignment='left')
|
||||
|
||||
def put_img_infos(self,
|
||||
img,
|
||||
infos,
|
||||
text_color='white',
|
||||
font_size=26,
|
||||
row_width=20,
|
||||
win_name='',
|
||||
show=True,
|
||||
wait_time=0,
|
||||
out_file=None):
|
||||
"""Show image with extra information.
|
||||
|
||||
Args:
|
||||
img (str | ndarray): The image to be displayed.
|
||||
infos (dict): Extra infos to display in the image.
|
||||
text_color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Extra infos
|
||||
display color. Defaults to 'white'.
|
||||
font_size (int): Extra infos display font size. Defaults to 26.
|
||||
row_width (int): width between each row of results on the image.
|
||||
win_name (str): The image title. Defaults to ''
|
||||
show (bool): Whether to show the image. Defaults to True.
|
||||
wait_time (int): How many seconds to display the image.
|
||||
Defaults to 0.
|
||||
out_file (Optional[str]): The filename to write the image.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The image with extra infomations.
|
||||
"""
|
||||
self.prepare()
|
||||
|
||||
text_color = color_val_matplotlib(text_color)
|
||||
img = mmcv.imread(img).astype(np.uint8)
|
||||
|
||||
x, y = 3, row_width // 2
|
||||
img = mmcv.bgr2rgb(img)
|
||||
width, height = img.shape[1], img.shape[0]
|
||||
img = np.ascontiguousarray(img)
|
||||
|
||||
# add a small EPS to avoid precision lost due to matplotlib's
|
||||
# truncation (https://github.com/matplotlib/matplotlib/issues/15363)
|
||||
dpi = self.fig_save.get_dpi()
|
||||
self.fig_save.set_size_inches((width + EPS) / dpi,
|
||||
(height + EPS) / dpi)
|
||||
|
||||
for k, v in infos.items():
|
||||
if isinstance(v, float):
|
||||
v = f'{v:.2f}'
|
||||
label_text = f'{k}: {v}'
|
||||
self._put_text(self.ax_save, label_text, x, y, text_color,
|
||||
font_size)
|
||||
if show and not self.is_inline:
|
||||
self._put_text(self.ax_show, label_text, x, y, text_color,
|
||||
font_size)
|
||||
y += row_width
|
||||
|
||||
self.ax_save.imshow(img)
|
||||
stream, _ = self.fig_save.canvas.print_to_buffer()
|
||||
buffer = np.frombuffer(stream, dtype='uint8')
|
||||
img_rgba = buffer.reshape(height, width, 4)
|
||||
rgb, _ = np.split(img_rgba, [3], axis=2)
|
||||
img_save = rgb.astype('uint8')
|
||||
img_save = mmcv.rgb2bgr(img_save)
|
||||
|
||||
if out_file is not None:
|
||||
mmcv.imwrite(img_save, out_file)
|
||||
|
||||
ret = 0
|
||||
if show and not self.is_inline:
|
||||
# Reserve some space for the tip.
|
||||
self.ax_show.set_title(win_name)
|
||||
self.ax_show.set_ylim(height + 20)
|
||||
self.ax_show.text(
|
||||
width // 2,
|
||||
height + 18,
|
||||
'Press SPACE to continue.',
|
||||
ha='center',
|
||||
fontsize=font_size)
|
||||
self.ax_show.imshow(img)
|
||||
|
||||
# Refresh canvas, necessary for Qt5 backend.
|
||||
self.fig_show.canvas.draw()
|
||||
|
||||
ret = self.wait_continue(timeout=wait_time)
|
||||
elif (not show) and self.is_inline:
|
||||
# If use inline backend, we use fig_save to show the image
|
||||
# So we need to close it if users don't want to show.
|
||||
plt.close(self.fig_save)
|
||||
|
||||
return ret, img_save
|
||||
|
||||
|
||||
def imshow_infos(img,
|
||||
infos,
|
||||
text_color='white',
|
||||
font_size=26,
|
||||
row_width=20,
|
||||
win_name='',
|
||||
show=True,
|
||||
fig_size=(15, 10),
|
||||
wait_time=0,
|
||||
out_file=None):
|
||||
"""Show image with extra information.
|
||||
|
||||
Args:
|
||||
img (str | ndarray): The image to be displayed.
|
||||
infos (dict): Extra infos to display in the image.
|
||||
text_color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Extra infos
|
||||
display color. Defaults to 'white'.
|
||||
font_size (int): Extra infos display font size. Defaults to 26.
|
||||
row_width (int): width between each row of results on the image.
|
||||
win_name (str): The image title. Defaults to ''
|
||||
show (bool): Whether to show the image. Defaults to True.
|
||||
fig_size (tuple): Image show figure size. Defaults to (15, 10).
|
||||
wait_time (int): How many seconds to display the image. Defaults to 0.
|
||||
out_file (Optional[str]): The filename to write the image.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The image with extra infomations.
|
||||
"""
|
||||
with ImshowInfosContextManager(fig_size=fig_size) as manager:
|
||||
_, img = manager.put_img_infos(
|
||||
img,
|
||||
infos,
|
||||
text_color=text_color,
|
||||
font_size=font_size,
|
||||
row_width=row_width,
|
||||
win_name=win_name,
|
||||
show=show,
|
||||
wait_time=wait_time,
|
||||
out_file=out_file)
|
||||
return img
|
|
@ -3,13 +3,10 @@ from abc import ABCMeta, abstractmethod
|
|||
from collections import OrderedDict
|
||||
from typing import Sequence
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from mmcv.runner import BaseModule, auto_fp16
|
||||
|
||||
from mmcls.core.visualization import imshow_infos
|
||||
|
||||
|
||||
class BaseClassifier(BaseModule, metaclass=ABCMeta):
|
||||
"""Base class for classifiers."""
|
||||
|
@ -174,51 +171,3 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
|
|||
loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
|
||||
|
||||
return outputs
|
||||
|
||||
def show_result(self,
|
||||
img,
|
||||
result,
|
||||
text_color='white',
|
||||
font_scale=0.5,
|
||||
row_width=20,
|
||||
show=False,
|
||||
fig_size=(15, 10),
|
||||
win_name='',
|
||||
wait_time=0,
|
||||
out_file=None):
|
||||
"""Draw `result` over `img`.
|
||||
|
||||
Args:
|
||||
img (str or ndarray): The image to be displayed.
|
||||
result (dict): The classification results to draw over `img`.
|
||||
text_color (str or tuple or :obj:`Color`): Color of texts.
|
||||
font_scale (float): Font scales of texts.
|
||||
row_width (int): width between each row of results on the image.
|
||||
show (bool): Whether to show the image.
|
||||
Default: False.
|
||||
fig_size (tuple): Image show figure size. Defaults to (15, 10).
|
||||
win_name (str): The window name.
|
||||
wait_time (int): How many seconds to display the image.
|
||||
Defaults to 0.
|
||||
out_file (str or None): The filename to write the image.
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
img (ndarray): Image with overlaid results.
|
||||
"""
|
||||
img = mmcv.imread(img)
|
||||
img = img.copy()
|
||||
|
||||
img = imshow_infos(
|
||||
img,
|
||||
result,
|
||||
text_color=text_color,
|
||||
font_size=int(font_scale * 50),
|
||||
row_width=row_width,
|
||||
win_name=win_name,
|
||||
show=show,
|
||||
fig_size=fig_size,
|
||||
wait_time=wait_time,
|
||||
out_file=out_file)
|
||||
|
||||
return img
|
||||
|
|
|
@ -45,21 +45,21 @@ DATASETS = Registry('dataset', parent=MMENGINE_DATASETS)
|
|||
DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS)
|
||||
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)
|
||||
|
||||
# mangage all kinds of modules inheriting `nn.Module`
|
||||
# manage all kinds of modules inheriting `nn.Module`
|
||||
MODELS = Registry('model', parent=MMENGINE_MODELS)
|
||||
# mangage all kinds of model wrappers like 'MMDistributedDataParallel'
|
||||
# manage all kinds of model wrappers like 'MMDistributedDataParallel'
|
||||
MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS)
|
||||
# mangage all kinds of weight initialization modules like `Uniform`
|
||||
# manage all kinds of weight initialization modules like `Uniform`
|
||||
WEIGHT_INITIALIZERS = Registry(
|
||||
'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS)
|
||||
|
||||
# Registries For Optimizer and the related
|
||||
# mangage all kinds of optimizers like `SGD` and `Adam`
|
||||
# manage all kinds of optimizers like `SGD` and `Adam`
|
||||
OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
|
||||
# manage constructors that customize the optimization hyperparameters.
|
||||
OPTIMIZER_CONSTRUCTORS = Registry(
|
||||
'optimizer constructor', parent=MMENGINE_OPTIMIZER_CONSTRUCTORS)
|
||||
# mangage all kinds of parameter schedulers like `MultiStepLR`
|
||||
# manage all kinds of parameter schedulers like `MultiStepLR`
|
||||
PARAM_SCHEDULERS = Registry(
|
||||
'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)
|
||||
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
# Copyright (c) Open-MMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmcls.core import ClsDataSample, ClsVisualizer
|
||||
|
||||
|
||||
class TestClsVisualizer(TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
tmpdir = tempfile.TemporaryDirectory()
|
||||
self.tmpdir = tmpdir
|
||||
self.vis = ClsVisualizer(
|
||||
save_dir=tmpdir.name,
|
||||
vis_backends=[dict(type='LocalVisBackend')],
|
||||
)
|
||||
|
||||
def test_add_datasample(self):
|
||||
image = np.ones((10, 10, 3), np.uint8)
|
||||
data_sample = ClsDataSample().set_gt_label(1).set_pred_label(1).\
|
||||
set_pred_score(torch.tensor([0.1, 0.8, 0.1]))
|
||||
|
||||
# Test show
|
||||
def mock_show(drawn_img, win_name, wait_time):
|
||||
self.assertFalse((image == drawn_img).all())
|
||||
self.assertEqual(win_name, 'test')
|
||||
self.assertEqual(wait_time, 0)
|
||||
|
||||
with patch.object(self.vis, 'show', mock_show):
|
||||
self.vis.add_datasample(
|
||||
'test', image=image, data_sample=data_sample, show=True)
|
||||
|
||||
# Test out_file
|
||||
out_file = osp.join(self.tmpdir.name, 'results.png')
|
||||
self.vis.add_datasample(
|
||||
'test', image=image, data_sample=data_sample, out_file=out_file)
|
||||
self.assertTrue(osp.exists(out_file))
|
||||
|
||||
# Test storage backend.
|
||||
save_file = osp.join(self.tmpdir.name, 'vis_data/vis_image/test_0.png')
|
||||
self.assertTrue(osp.exists(save_file))
|
||||
|
||||
# Test with dataset_meta
|
||||
self.vis.dataset_meta = {'CLASSES': ['cat', 'bird', 'dog']}
|
||||
|
||||
def test_texts(text, *_, **__):
|
||||
self.assertEqual(
|
||||
text, '\n'.join([
|
||||
'Ground truth: 1 (bird)',
|
||||
'Prediction: 1, 0.80 (bird)',
|
||||
]))
|
||||
|
||||
with patch.object(self.vis, 'draw_texts', test_texts):
|
||||
self.vis.add_datasample(
|
||||
'test', image=image, data_sample=data_sample)
|
||||
|
||||
# Test without pred_label
|
||||
def test_texts(text, *_, **__):
|
||||
self.assertEqual(text, '\n'.join([
|
||||
'Ground truth: 1 (bird)',
|
||||
]))
|
||||
|
||||
with patch.object(self.vis, 'draw_texts', test_texts):
|
||||
self.vis.add_datasample(
|
||||
'test', image=image, data_sample=data_sample, draw_pred=False)
|
||||
|
||||
# Test without gt_label
|
||||
def test_texts(text, *_, **__):
|
||||
self.assertEqual(text, '\n'.join([
|
||||
'Prediction: 1, 0.80 (bird)',
|
||||
]))
|
||||
|
||||
with patch.object(self.vis, 'draw_texts', test_texts):
|
||||
self.vis.add_datasample(
|
||||
'test', image=image, data_sample=data_sample, draw_gt=False)
|
||||
|
||||
# Test without score
|
||||
del data_sample.pred_label.score
|
||||
|
||||
def test_texts(text, *_, **__):
|
||||
self.assertEqual(
|
||||
text, '\n'.join([
|
||||
'Ground truth: 1 (bird)',
|
||||
'Prediction: 1 (bird)',
|
||||
]))
|
||||
|
||||
with patch.object(self.vis, 'draw_texts', test_texts):
|
||||
self.vis.add_datasample(
|
||||
'test', image=image, data_sample=data_sample)
|
||||
|
||||
def tearDown(self):
|
||||
self.tmpdir.cleanup()
|
|
@ -1,118 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
from mmcv import Config
|
||||
from mmdet.apis import inference_detector
|
||||
from mmdet.models import build_detector
|
||||
|
||||
from mmcls.models import (MobileNetV2, MobileNetV3, RegNet, ResNeSt, ResNet,
|
||||
ResNeXt, SEResNet, SEResNeXt, SwinTransformer,
|
||||
TIMMBackbone)
|
||||
from mmcls.models.backbones.timm_backbone import timm
|
||||
|
||||
backbone_configs = dict(
|
||||
mobilenetv2=dict(
|
||||
backbone=dict(
|
||||
type='mmcls.MobileNetV2',
|
||||
widen_factor=1.0,
|
||||
norm_cfg=dict(type='GN', num_groups=2, requires_grad=True),
|
||||
out_indices=(4, 7)),
|
||||
out_channels=[96, 1280]),
|
||||
mobilenetv3=dict(
|
||||
backbone=dict(
|
||||
type='mmcls.MobileNetV3',
|
||||
norm_cfg=dict(type='GN', num_groups=2, requires_grad=True),
|
||||
out_indices=range(7, 12)),
|
||||
out_channels=[48, 48, 96, 96, 96]),
|
||||
regnet=dict(
|
||||
backbone=dict(type='mmcls.RegNet', arch='regnetx_400mf'),
|
||||
out_channels=384),
|
||||
resnext=dict(
|
||||
backbone=dict(
|
||||
type='mmcls.ResNeXt', depth=50, groups=32, width_per_group=4),
|
||||
out_channels=2048),
|
||||
resnet=dict(
|
||||
backbone=dict(type='mmcls.ResNet', depth=50), out_channels=2048),
|
||||
seresnet=dict(
|
||||
backbone=dict(type='mmcls.SEResNet', depth=50), out_channels=2048),
|
||||
seresnext=dict(
|
||||
backbone=dict(
|
||||
type='mmcls.SEResNeXt', depth=50, groups=32, width_per_group=4),
|
||||
out_channels=2048),
|
||||
resnest=dict(
|
||||
backbone=dict(
|
||||
type='mmcls.ResNeSt',
|
||||
depth=50,
|
||||
radix=2,
|
||||
reduction_factor=4,
|
||||
out_indices=(0, 1, 2, 3)),
|
||||
out_channels=[256, 512, 1024, 2048]),
|
||||
swin=dict(
|
||||
backbone=dict(
|
||||
type='mmcls.SwinTransformer',
|
||||
arch='small',
|
||||
drop_path_rate=0.2,
|
||||
img_size=800,
|
||||
out_indices=(2, 3)),
|
||||
out_channels=[384, 768]),
|
||||
timm_efficientnet=dict(
|
||||
backbone=dict(
|
||||
type='mmcls.TIMMBackbone',
|
||||
model_name='efficientnet_b1',
|
||||
features_only=True,
|
||||
pretrained=False,
|
||||
out_indices=(1, 2, 3, 4)),
|
||||
out_channels=[24, 40, 112, 320]),
|
||||
timm_resnet=dict(
|
||||
backbone=dict(
|
||||
type='mmcls.TIMMBackbone',
|
||||
model_name='resnet50',
|
||||
features_only=True,
|
||||
pretrained=False,
|
||||
out_indices=(1, 2, 3, 4)),
|
||||
out_channels=[256, 512, 1024, 2048]))
|
||||
|
||||
module_mapping = {
|
||||
'mobilenetv2': MobileNetV2,
|
||||
'mobilenetv3': MobileNetV3,
|
||||
'regnet': RegNet,
|
||||
'resnext': ResNeXt,
|
||||
'resnet': ResNet,
|
||||
'seresnext': SEResNeXt,
|
||||
'seresnet': SEResNet,
|
||||
'resnest': ResNeSt,
|
||||
'swin': SwinTransformer,
|
||||
'timm_efficientnet': TIMMBackbone,
|
||||
'timm_resnet': TIMMBackbone
|
||||
}
|
||||
|
||||
|
||||
def test_mmdet_inference():
|
||||
config_path = './tests/data/retinanet.py'
|
||||
rng = np.random.RandomState(0)
|
||||
img1 = rng.rand(100, 100, 3)
|
||||
|
||||
for module_name, backbone_config in backbone_configs.items():
|
||||
module = module_mapping[module_name]
|
||||
if module is TIMMBackbone and timm is None:
|
||||
print(f'skip {module_name} because timm is not available')
|
||||
continue
|
||||
print(f'test {module_name}')
|
||||
config = Config.fromfile(config_path)
|
||||
config.model.backbone = backbone_config['backbone']
|
||||
out_channels = backbone_config['out_channels']
|
||||
if isinstance(out_channels, int):
|
||||
config.model.neck = None
|
||||
config.model.bbox_head.in_channels = out_channels
|
||||
anchor_generator = config.model.bbox_head.anchor_generator
|
||||
anchor_generator.strides = anchor_generator.strides[:1]
|
||||
else:
|
||||
config.model.neck.in_channels = out_channels
|
||||
|
||||
model = build_detector(config.model)
|
||||
assert isinstance(model.backbone, module)
|
||||
|
||||
model.cfg = config
|
||||
|
||||
model.eval()
|
||||
result = inference_detector(model, img1)
|
||||
assert len(result) == config.num_classes
|
|
@ -1,9 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv import ConfigDict
|
||||
|
||||
|
@ -90,20 +87,6 @@ def test_image_classifier():
|
|||
model = CLASSIFIERS.build(model_cfg_)
|
||||
assert model.init_cfg == dict(type='Pretrained', checkpoint='checkpoint')
|
||||
|
||||
# test show_result
|
||||
img = np.random.randint(0, 256, (224, 224, 3)).astype(np.uint8)
|
||||
result = dict(pred_class='cat', pred_label=0, pred_score=0.9)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
out_file = osp.join(tmpdir, 'out.png')
|
||||
model.show_result(img, result, out_file=out_file)
|
||||
assert osp.exists(out_file)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
out_file = osp.join(tmpdir, 'out.png')
|
||||
model.show_result(img, result, out_file=out_file)
|
||||
assert osp.exists(out_file)
|
||||
|
||||
|
||||
def test_image_classifier_with_mixup():
|
||||
# Test mixup in ImageClassifier
|
||||
|
|
|
@ -1,158 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.runner import build_runner
|
||||
from mmengine.hooks import Hook, IterTimerHook
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import mmcls.core # noqa: F401
|
||||
|
||||
|
||||
def _build_demo_runner_without_hook(runner_type='EpochBasedRunner',
|
||||
max_epochs=1,
|
||||
max_iters=None,
|
||||
multi_optimziers=False):
|
||||
|
||||
class Model(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(2, 1)
|
||||
self.conv = nn.Conv2d(3, 3, 3)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
def train_step(self, x, optimizer, **kwargs):
|
||||
return dict(loss=self(x))
|
||||
|
||||
def val_step(self, x, optimizer, **kwargs):
|
||||
return dict(loss=self(x))
|
||||
|
||||
model = Model()
|
||||
|
||||
if multi_optimziers:
|
||||
optimizer = {
|
||||
'model1':
|
||||
torch.optim.SGD(model.linear.parameters(), lr=0.02, momentum=0.95),
|
||||
'model2':
|
||||
torch.optim.SGD(model.conv.parameters(), lr=0.01, momentum=0.9),
|
||||
}
|
||||
else:
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)
|
||||
|
||||
tmp_dir = tempfile.mkdtemp()
|
||||
runner = build_runner(
|
||||
dict(type=runner_type),
|
||||
default_args=dict(
|
||||
model=model,
|
||||
work_dir=tmp_dir,
|
||||
optimizer=optimizer,
|
||||
logger=logging.getLogger(),
|
||||
max_epochs=max_epochs,
|
||||
max_iters=max_iters))
|
||||
return runner
|
||||
|
||||
|
||||
def _build_demo_runner(runner_type='EpochBasedRunner',
|
||||
max_epochs=1,
|
||||
max_iters=None,
|
||||
multi_optimziers=False):
|
||||
|
||||
log_config = dict(
|
||||
interval=1, hooks=[
|
||||
dict(type='TextLoggerHook'),
|
||||
])
|
||||
|
||||
runner = _build_demo_runner_without_hook(runner_type, max_epochs,
|
||||
max_iters, multi_optimziers)
|
||||
|
||||
runner.register_checkpoint_hook(dict(interval=1))
|
||||
runner.register_logger_hooks(log_config)
|
||||
return runner
|
||||
|
||||
|
||||
class ValueCheckHook(Hook):
|
||||
|
||||
def __init__(self, check_dict, by_epoch=False):
|
||||
super().__init__()
|
||||
self.check_dict = check_dict
|
||||
self.by_epoch = by_epoch
|
||||
|
||||
def after_iter(self, runner):
|
||||
if self.by_epoch:
|
||||
return
|
||||
if runner.iter in self.check_dict:
|
||||
for attr, target in self.check_dict[runner.iter].items():
|
||||
value = eval(f'runner.{attr}')
|
||||
assert np.isclose(value, target), \
|
||||
(f'The value of `runner.{attr}` is {value}, '
|
||||
f'not equals to {target}')
|
||||
|
||||
def after_epoch(self, runner):
|
||||
if not self.by_epoch:
|
||||
return
|
||||
if runner.epoch in self.check_dict:
|
||||
for attr, target in self.check_dict[runner.epoch]:
|
||||
value = eval(f'runner.{attr}')
|
||||
assert np.isclose(value, target), \
|
||||
(f'The value of `runner.{attr}` is {value}, '
|
||||
f'not equals to {target}')
|
||||
|
||||
|
||||
@pytest.mark.parametrize('multi_optimziers', (True, False))
|
||||
def test_cosine_cooldown_hook(multi_optimziers):
|
||||
"""xdoctest -m tests/test_hooks.py test_cosine_runner_hook."""
|
||||
loader = DataLoader(torch.ones((10, 2)))
|
||||
runner = _build_demo_runner(multi_optimziers=multi_optimziers)
|
||||
|
||||
# add momentum LR scheduler
|
||||
hook_cfg = dict(
|
||||
type='CosineAnnealingCooldownLrUpdaterHook',
|
||||
by_epoch=False,
|
||||
cool_down_time=2,
|
||||
cool_down_ratio=0.1,
|
||||
min_lr_ratio=0.1,
|
||||
warmup_iters=2,
|
||||
warmup_ratio=0.9)
|
||||
runner.register_hook_from_cfg(hook_cfg)
|
||||
runner.register_hook_from_cfg(dict(type='IterTimerHook'))
|
||||
runner.register_hook(IterTimerHook())
|
||||
|
||||
if multi_optimziers:
|
||||
check_hook = ValueCheckHook({
|
||||
0: {
|
||||
'current_lr()["model1"][0]': 0.02,
|
||||
'current_lr()["model2"][0]': 0.01,
|
||||
},
|
||||
5: {
|
||||
'current_lr()["model1"][0]': 0.0075558491,
|
||||
'current_lr()["model2"][0]': 0.0037779246,
|
||||
},
|
||||
9: {
|
||||
'current_lr()["model1"][0]': 0.0002,
|
||||
'current_lr()["model2"][0]': 0.0001,
|
||||
}
|
||||
})
|
||||
else:
|
||||
check_hook = ValueCheckHook({
|
||||
0: {
|
||||
'current_lr()[0]': 0.02,
|
||||
},
|
||||
5: {
|
||||
'current_lr()[0]': 0.0075558491,
|
||||
},
|
||||
9: {
|
||||
'current_lr()[0]': 0.0002,
|
||||
}
|
||||
})
|
||||
runner.register_hook(check_hook, priority='LOWEST')
|
||||
|
||||
runner.run([loader], [('train', 1)])
|
||||
shutil.rmtree(runner.work_dir)
|
|
@ -1,84 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import mmcv.runner as mmcv_runner
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv.runner import obj_from_dict
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmcls.core.hook import ClassNumCheckHook
|
||||
from mmcls.models.heads.base_head import BaseHead
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
||||
def __init__(self, CLASSES):
|
||||
self.CLASSES = CLASSES
|
||||
|
||||
def __getitem__(self, idx):
|
||||
results = dict(img=torch.tensor([1]), img_metas=dict())
|
||||
return results
|
||||
|
||||
def __len__(self):
|
||||
return 1
|
||||
|
||||
|
||||
class ExampleHead(BaseHead):
|
||||
|
||||
def __init__(self, init_cfg=None):
|
||||
super(BaseHead, self).__init__(init_cfg)
|
||||
self.num_classes = 4
|
||||
|
||||
def forward_train(self, x, gt_label=None, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class ExampleModel(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(ExampleModel, self).__init__()
|
||||
self.test_cfg = None
|
||||
self.conv = torch.nn.Conv2d(3, 3, 3)
|
||||
self.head = ExampleHead()
|
||||
|
||||
def forward(self, img, img_metas, test_mode=False, **kwargs):
|
||||
return img
|
||||
|
||||
def train_step(self, data_batch, optimizer):
|
||||
loss = self.forward(**data_batch)
|
||||
return dict(loss=loss)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('runner_type',
|
||||
['EpochBasedRunner', 'IterBasedRunner'])
|
||||
@pytest.mark.parametrize(
|
||||
'CLASSES', [None, ('A', 'B', 'C', 'D', 'E'), ('A', 'B', 'C', 'D')])
|
||||
def test_num_class_hook(runner_type, CLASSES):
|
||||
test_dataset = ExampleDataset(CLASSES)
|
||||
loader = DataLoader(test_dataset, batch_size=1)
|
||||
model = ExampleModel()
|
||||
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
|
||||
optimizer = obj_from_dict(optim_cfg, torch.optim,
|
||||
dict(params=model.parameters()))
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
num_class_hook = ClassNumCheckHook()
|
||||
logger_mock = MagicMock(spec=logging.Logger)
|
||||
runner = getattr(mmcv_runner, runner_type)(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
work_dir=tmpdir,
|
||||
logger=logger_mock,
|
||||
max_epochs=1)
|
||||
runner.register_hook(num_class_hook)
|
||||
if CLASSES is None:
|
||||
runner.run([loader], [('train', 1)], 1)
|
||||
logger_mock.warning.assert_called()
|
||||
elif len(CLASSES) != 4:
|
||||
with pytest.raises(AssertionError):
|
||||
runner.run([loader], [('train', 1)], 1)
|
||||
else:
|
||||
runner.run([loader], [('train', 1)], 1)
|
|
@ -1,309 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import functools
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.runner import build_optimizer
|
||||
from mmcv.runner.optimizer.builder import OPTIMIZERS
|
||||
from mmcv.utils.registry import build_from_cfg
|
||||
from torch.autograd import Variable
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
import mmcls.core # noqa: F401
|
||||
|
||||
base_lr = 0.01
|
||||
base_wd = 0.0001
|
||||
|
||||
|
||||
def assert_equal(x, y):
|
||||
if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
|
||||
torch.testing.assert_allclose(x, y.to(x.device))
|
||||
elif isinstance(x, OrderedDict) and isinstance(y, OrderedDict):
|
||||
for x_value, y_value in zip(x.values(), y.values()):
|
||||
assert_equal(x_value, y_value)
|
||||
elif isinstance(x, dict) and isinstance(y, dict):
|
||||
assert x.keys() == y.keys()
|
||||
for key in x.keys():
|
||||
assert_equal(x[key], y[key])
|
||||
elif isinstance(x, str) and isinstance(y, str):
|
||||
assert x == y
|
||||
elif isinstance(x, Iterable) and isinstance(y, Iterable):
|
||||
assert len(x) == len(y)
|
||||
for x_item, y_item in zip(x, y):
|
||||
assert_equal(x_item, y_item)
|
||||
else:
|
||||
assert x == y
|
||||
|
||||
|
||||
class SubModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(2, 2, kernel_size=1, groups=2)
|
||||
self.gn = nn.GroupNorm(2, 2)
|
||||
self.fc = nn.Linear(2, 2)
|
||||
self.param1 = nn.Parameter(torch.ones(1))
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class ExampleModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param1 = nn.Parameter(torch.ones(1))
|
||||
self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False)
|
||||
self.conv2 = nn.Conv2d(4, 2, kernel_size=1)
|
||||
self.bn = nn.BatchNorm2d(2)
|
||||
self.sub = SubModel()
|
||||
self.fc = nn.Linear(2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
def check_lamb_optimizer(optimizer,
|
||||
model,
|
||||
bias_lr_mult=1,
|
||||
bias_decay_mult=1,
|
||||
norm_decay_mult=1,
|
||||
dwconv_decay_mult=1):
|
||||
param_groups = optimizer.param_groups
|
||||
assert isinstance(optimizer, Optimizer)
|
||||
assert optimizer.defaults['lr'] == base_lr
|
||||
assert optimizer.defaults['weight_decay'] == base_wd
|
||||
model_parameters = list(model.parameters())
|
||||
assert len(param_groups) == len(model_parameters)
|
||||
for i, param in enumerate(model_parameters):
|
||||
param_group = param_groups[i]
|
||||
assert torch.equal(param_group['params'][0], param)
|
||||
# param1
|
||||
param1 = param_groups[0]
|
||||
assert param1['lr'] == base_lr
|
||||
assert param1['weight_decay'] == base_wd
|
||||
# conv1.weight
|
||||
conv1_weight = param_groups[1]
|
||||
assert conv1_weight['lr'] == base_lr
|
||||
assert conv1_weight['weight_decay'] == base_wd
|
||||
# conv2.weight
|
||||
conv2_weight = param_groups[2]
|
||||
assert conv2_weight['lr'] == base_lr
|
||||
assert conv2_weight['weight_decay'] == base_wd
|
||||
# conv2.bias
|
||||
conv2_bias = param_groups[3]
|
||||
assert conv2_bias['lr'] == base_lr * bias_lr_mult
|
||||
assert conv2_bias['weight_decay'] == base_wd * bias_decay_mult
|
||||
# bn.weight
|
||||
bn_weight = param_groups[4]
|
||||
assert bn_weight['lr'] == base_lr
|
||||
assert bn_weight['weight_decay'] == base_wd * norm_decay_mult
|
||||
# bn.bias
|
||||
bn_bias = param_groups[5]
|
||||
assert bn_bias['lr'] == base_lr
|
||||
assert bn_bias['weight_decay'] == base_wd * norm_decay_mult
|
||||
# sub.param1
|
||||
sub_param1 = param_groups[6]
|
||||
assert sub_param1['lr'] == base_lr
|
||||
assert sub_param1['weight_decay'] == base_wd
|
||||
# sub.conv1.weight
|
||||
sub_conv1_weight = param_groups[7]
|
||||
assert sub_conv1_weight['lr'] == base_lr
|
||||
assert sub_conv1_weight['weight_decay'] == base_wd * dwconv_decay_mult
|
||||
# sub.conv1.bias
|
||||
sub_conv1_bias = param_groups[8]
|
||||
assert sub_conv1_bias['lr'] == base_lr * bias_lr_mult
|
||||
assert sub_conv1_bias['weight_decay'] == base_wd * dwconv_decay_mult
|
||||
# sub.gn.weight
|
||||
sub_gn_weight = param_groups[9]
|
||||
assert sub_gn_weight['lr'] == base_lr
|
||||
assert sub_gn_weight['weight_decay'] == base_wd * norm_decay_mult
|
||||
# sub.gn.bias
|
||||
sub_gn_bias = param_groups[10]
|
||||
assert sub_gn_bias['lr'] == base_lr
|
||||
assert sub_gn_bias['weight_decay'] == base_wd * norm_decay_mult
|
||||
# sub.fc1.weight
|
||||
sub_fc_weight = param_groups[11]
|
||||
assert sub_fc_weight['lr'] == base_lr
|
||||
assert sub_fc_weight['weight_decay'] == base_wd
|
||||
# sub.fc1.bias
|
||||
sub_fc_bias = param_groups[12]
|
||||
assert sub_fc_bias['lr'] == base_lr * bias_lr_mult
|
||||
assert sub_fc_bias['weight_decay'] == base_wd * bias_decay_mult
|
||||
# fc1.weight
|
||||
fc_weight = param_groups[13]
|
||||
assert fc_weight['lr'] == base_lr
|
||||
assert fc_weight['weight_decay'] == base_wd
|
||||
# fc1.bias
|
||||
fc_bias = param_groups[14]
|
||||
assert fc_bias['lr'] == base_lr * bias_lr_mult
|
||||
assert fc_bias['weight_decay'] == base_wd * bias_decay_mult
|
||||
|
||||
|
||||
def _test_state_dict(weight, bias, input, constructor):
|
||||
weight = Variable(weight, requires_grad=True)
|
||||
bias = Variable(bias, requires_grad=True)
|
||||
inputs = Variable(input)
|
||||
|
||||
def fn_base(optimizer, weight, bias):
|
||||
optimizer.zero_grad()
|
||||
i = input_cuda if weight.is_cuda else inputs
|
||||
loss = (weight.mv(i) + bias).pow(2).sum()
|
||||
loss.backward()
|
||||
return loss
|
||||
|
||||
optimizer = constructor(weight, bias)
|
||||
fn = functools.partial(fn_base, optimizer, weight, bias)
|
||||
|
||||
# Prime the optimizer
|
||||
for _ in range(20):
|
||||
optimizer.step(fn)
|
||||
# Clone the weights and construct new optimizer for them
|
||||
weight_c = Variable(weight.data.clone(), requires_grad=True)
|
||||
bias_c = Variable(bias.data.clone(), requires_grad=True)
|
||||
optimizer_c = constructor(weight_c, bias_c)
|
||||
fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
|
||||
# Load state dict
|
||||
state_dict = deepcopy(optimizer.state_dict())
|
||||
state_dict_c = deepcopy(optimizer.state_dict())
|
||||
optimizer_c.load_state_dict(state_dict_c)
|
||||
# Run both optimizations in parallel
|
||||
for _ in range(20):
|
||||
optimizer.step(fn)
|
||||
optimizer_c.step(fn_c)
|
||||
assert_equal(weight, weight_c)
|
||||
assert_equal(bias, bias_c)
|
||||
# Make sure state dict wasn't modified
|
||||
assert_equal(state_dict, state_dict_c)
|
||||
# Make sure state dict is deterministic with equal
|
||||
# but not identical parameters
|
||||
# NOTE: The state_dict of optimizers in PyTorch 1.5 have random keys,
|
||||
state_dict = deepcopy(optimizer.state_dict())
|
||||
state_dict_c = deepcopy(optimizer_c.state_dict())
|
||||
keys = state_dict['param_groups'][-1]['params']
|
||||
keys_c = state_dict_c['param_groups'][-1]['params']
|
||||
for key, key_c in zip(keys, keys_c):
|
||||
assert_equal(optimizer.state_dict()['state'][key],
|
||||
optimizer_c.state_dict()['state'][key_c])
|
||||
# Make sure repeated parameters have identical representation in state dict
|
||||
optimizer_c.param_groups.extend(optimizer_c.param_groups)
|
||||
assert_equal(optimizer_c.state_dict()['param_groups'][0],
|
||||
optimizer_c.state_dict()['param_groups'][1])
|
||||
|
||||
# Check that state dict can be loaded even when we cast parameters
|
||||
# to a different type and move to a different device.
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
|
||||
input_cuda = Variable(inputs.data.float().cuda())
|
||||
weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
|
||||
bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
|
||||
optimizer_cuda = constructor(weight_cuda, bias_cuda)
|
||||
fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda,
|
||||
bias_cuda)
|
||||
|
||||
state_dict = deepcopy(optimizer.state_dict())
|
||||
state_dict_c = deepcopy(optimizer.state_dict())
|
||||
optimizer_cuda.load_state_dict(state_dict_c)
|
||||
|
||||
# Make sure state dict wasn't modified
|
||||
assert_equal(state_dict, state_dict_c)
|
||||
|
||||
for _ in range(20):
|
||||
optimizer.step(fn)
|
||||
optimizer_cuda.step(fn_cuda)
|
||||
assert_equal(weight, weight_cuda)
|
||||
assert_equal(bias, bias_cuda)
|
||||
|
||||
# validate deepcopy() copies all public attributes
|
||||
def getPublicAttr(obj):
|
||||
return set(k for k in obj.__dict__ if not k.startswith('_'))
|
||||
|
||||
assert_equal(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer)))
|
||||
|
||||
|
||||
def _test_basic_cases_template(weight, bias, inputs, constructor,
|
||||
scheduler_constructors):
|
||||
"""Copied from PyTorch."""
|
||||
weight = Variable(weight, requires_grad=True)
|
||||
bias = Variable(bias, requires_grad=True)
|
||||
inputs = Variable(inputs)
|
||||
optimizer = constructor(weight, bias)
|
||||
schedulers = []
|
||||
for scheduler_constructor in scheduler_constructors:
|
||||
schedulers.append(scheduler_constructor(optimizer))
|
||||
|
||||
# to check if the optimizer can be printed as a string
|
||||
optimizer.__repr__()
|
||||
|
||||
def fn():
|
||||
optimizer.zero_grad()
|
||||
y = weight.mv(inputs)
|
||||
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
|
||||
y = y.cuda(bias.get_device())
|
||||
loss = (y + bias).pow(2).sum()
|
||||
loss.backward()
|
||||
return loss
|
||||
|
||||
initial_value = fn().item()
|
||||
for _ in range(200):
|
||||
for scheduler in schedulers:
|
||||
scheduler.step()
|
||||
optimizer.step(fn)
|
||||
|
||||
assert fn().item() < initial_value
|
||||
|
||||
|
||||
def _test_basic_cases(constructor,
|
||||
scheduler_constructors=None,
|
||||
ignore_multidevice=False):
|
||||
"""Copied from PyTorch."""
|
||||
if scheduler_constructors is None:
|
||||
scheduler_constructors = []
|
||||
_test_state_dict(
|
||||
torch.randn(10, 5), torch.randn(10), torch.randn(5), constructor)
|
||||
_test_basic_cases_template(
|
||||
torch.randn(10, 5), torch.randn(10), torch.randn(5), constructor,
|
||||
scheduler_constructors)
|
||||
# non-contiguous parameters
|
||||
_test_basic_cases_template(
|
||||
torch.randn(10, 5, 2)[..., 0],
|
||||
torch.randn(10, 2)[..., 0], torch.randn(5), constructor,
|
||||
scheduler_constructors)
|
||||
# CUDA
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
_test_basic_cases_template(
|
||||
torch.randn(10, 5).cuda(),
|
||||
torch.randn(10).cuda(),
|
||||
torch.randn(5).cuda(), constructor, scheduler_constructors)
|
||||
# Multi-GPU
|
||||
if not torch.cuda.device_count() > 1 or ignore_multidevice:
|
||||
return
|
||||
_test_basic_cases_template(
|
||||
torch.randn(10, 5).cuda(0),
|
||||
torch.randn(10).cuda(1),
|
||||
torch.randn(5).cuda(0), constructor, scheduler_constructors)
|
||||
|
||||
|
||||
def test_lamb_optimizer():
|
||||
model = ExampleModel()
|
||||
optimizer_cfg = dict(
|
||||
type='Lamb',
|
||||
lr=base_lr,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
weight_decay=base_wd,
|
||||
paramwise_cfg=dict(
|
||||
bias_lr_mult=2,
|
||||
bias_decay_mult=0.5,
|
||||
norm_decay_mult=0,
|
||||
dwconv_decay_mult=0.1))
|
||||
optimizer = build_optimizer(model, optimizer_cfg)
|
||||
check_lamb_optimizer(optimizer, model, **optimizer_cfg['paramwise_cfg'])
|
||||
|
||||
_test_basic_cases(lambda weight, bias: build_from_cfg(
|
||||
dict(type='Lamb', params=[weight, bias], lr=base_lr), OPTIMIZERS))
|
|
@ -1,264 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
||||
from mmcv.runner import EpochBasedRunner, IterBasedRunner, build_optimizer
|
||||
from mmengine.logging import MMLogger, print_log
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmcls.core.hook import PreciseBNHook
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
||||
def __init__(self):
|
||||
self.index = 0
|
||||
|
||||
def __getitem__(self, idx):
|
||||
results = dict(imgs=torch.tensor([1.0], dtype=torch.float32))
|
||||
return results
|
||||
|
||||
def __len__(self):
|
||||
return 1
|
||||
|
||||
|
||||
class BiggerDataset(ExampleDataset):
|
||||
|
||||
def __init__(self, fixed_values=range(0, 12)):
|
||||
assert len(self) == len(fixed_values)
|
||||
self.fixed_values = fixed_values
|
||||
|
||||
def __getitem__(self, idx):
|
||||
results = dict(
|
||||
imgs=torch.tensor([self.fixed_values[idx]], dtype=torch.float32))
|
||||
return results
|
||||
|
||||
def __len__(self):
|
||||
# a bigger dataset
|
||||
return 12
|
||||
|
||||
|
||||
class ExampleModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Linear(1, 1)
|
||||
self.bn = nn.BatchNorm1d(1)
|
||||
self.test_cfg = None
|
||||
|
||||
def forward(self, imgs, return_loss=False):
|
||||
return self.bn(self.conv(imgs))
|
||||
|
||||
def train_step(self, data_batch, optimizer, **kwargs):
|
||||
outputs = {
|
||||
'loss': 0.5,
|
||||
'log_vars': {
|
||||
'accuracy': 0.98
|
||||
},
|
||||
'num_samples': 1
|
||||
}
|
||||
return outputs
|
||||
|
||||
|
||||
class SingleBNModel(ExampleModel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bn = nn.BatchNorm1d(1)
|
||||
self.test_cfg = None
|
||||
|
||||
def forward(self, imgs, return_loss=False):
|
||||
return self.bn(imgs)
|
||||
|
||||
|
||||
class GNExampleModel(ExampleModel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Linear(1, 1)
|
||||
self.bn = nn.GroupNorm(1, 1)
|
||||
self.test_cfg = None
|
||||
|
||||
|
||||
class NoBNExampleModel(ExampleModel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Linear(1, 1)
|
||||
self.test_cfg = None
|
||||
|
||||
def forward(self, imgs, return_loss=False):
|
||||
return self.conv(imgs)
|
||||
|
||||
|
||||
def test_precise_bn():
|
||||
optimizer_cfg = dict(
|
||||
type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
|
||||
|
||||
test_dataset = ExampleDataset()
|
||||
loader = DataLoader(test_dataset, batch_size=2)
|
||||
model = ExampleModel()
|
||||
optimizer = build_optimizer(model, optimizer_cfg)
|
||||
logger = MMLogger.get_instance('precise_bn')
|
||||
runner = EpochBasedRunner(
|
||||
model=model,
|
||||
batch_processor=None,
|
||||
optimizer=optimizer,
|
||||
logger=logger,
|
||||
max_epochs=1)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# num_samples must be larger than 0
|
||||
precise_bn_hook = PreciseBNHook(num_samples=-1)
|
||||
runner.register_hook(precise_bn_hook)
|
||||
runner.run([loader], [('train', 1)])
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# interval must be larger than 0
|
||||
precise_bn_hook = PreciseBNHook(interval=0)
|
||||
runner.register_hook(precise_bn_hook)
|
||||
runner.run([loader], [('train', 1)])
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# interval must be larger than 0
|
||||
runner = EpochBasedRunner(
|
||||
model=model,
|
||||
batch_processor=None,
|
||||
optimizer=optimizer,
|
||||
logger=logger,
|
||||
max_epochs=1)
|
||||
precise_bn_hook = PreciseBNHook(interval=0)
|
||||
runner.register_hook(precise_bn_hook)
|
||||
runner.run([loader], [('train', 1)])
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# only support EpochBaseRunner
|
||||
runner = IterBasedRunner(
|
||||
model=model,
|
||||
batch_processor=None,
|
||||
optimizer=optimizer,
|
||||
logger=logger,
|
||||
max_epochs=1)
|
||||
precise_bn_hook = PreciseBNHook(interval=2)
|
||||
runner.register_hook(precise_bn_hook)
|
||||
print_log(runner)
|
||||
runner.run([loader], [('train', 1)])
|
||||
|
||||
# test non-DDP model
|
||||
test_bigger_dataset = BiggerDataset()
|
||||
loader = DataLoader(test_bigger_dataset, batch_size=2)
|
||||
loaders = [loader]
|
||||
precise_bn_hook = PreciseBNHook(num_samples=4)
|
||||
assert precise_bn_hook.num_samples == 4
|
||||
assert precise_bn_hook.interval == 1
|
||||
runner = EpochBasedRunner(
|
||||
model=model,
|
||||
batch_processor=None,
|
||||
optimizer=optimizer,
|
||||
logger=logger,
|
||||
max_epochs=1)
|
||||
runner.register_hook(precise_bn_hook)
|
||||
runner.run(loaders, [('train', 1)])
|
||||
|
||||
# test DP model
|
||||
test_bigger_dataset = BiggerDataset()
|
||||
loader = DataLoader(test_bigger_dataset, batch_size=2)
|
||||
loaders = [loader]
|
||||
precise_bn_hook = PreciseBNHook(num_samples=4)
|
||||
assert precise_bn_hook.num_samples == 4
|
||||
assert precise_bn_hook.interval == 1
|
||||
model = MMDataParallel(model)
|
||||
runner = EpochBasedRunner(
|
||||
model=model,
|
||||
batch_processor=None,
|
||||
optimizer=optimizer,
|
||||
logger=logger,
|
||||
max_epochs=1)
|
||||
runner.register_hook(precise_bn_hook)
|
||||
runner.run(loaders, [('train', 1)])
|
||||
|
||||
# test model w/ gn layer
|
||||
loader = DataLoader(test_bigger_dataset, batch_size=2)
|
||||
loaders = [loader]
|
||||
precise_bn_hook = PreciseBNHook(num_samples=4)
|
||||
assert precise_bn_hook.num_samples == 4
|
||||
assert precise_bn_hook.interval == 1
|
||||
model = GNExampleModel()
|
||||
runner = EpochBasedRunner(
|
||||
model=model,
|
||||
batch_processor=None,
|
||||
optimizer=optimizer,
|
||||
logger=logger,
|
||||
max_epochs=1)
|
||||
runner.register_hook(precise_bn_hook)
|
||||
runner.run(loaders, [('train', 1)])
|
||||
|
||||
# test model without bn layer
|
||||
loader = DataLoader(test_bigger_dataset, batch_size=2)
|
||||
loaders = [loader]
|
||||
precise_bn_hook = PreciseBNHook(num_samples=4)
|
||||
assert precise_bn_hook.num_samples == 4
|
||||
assert precise_bn_hook.interval == 1
|
||||
model = NoBNExampleModel()
|
||||
runner = EpochBasedRunner(
|
||||
model=model,
|
||||
batch_processor=None,
|
||||
optimizer=optimizer,
|
||||
logger=logger,
|
||||
max_epochs=1)
|
||||
runner.register_hook(precise_bn_hook)
|
||||
runner.run(loaders, [('train', 1)])
|
||||
|
||||
# test how precise it is
|
||||
loader = DataLoader(test_bigger_dataset, batch_size=2)
|
||||
loaders = [loader]
|
||||
precise_bn_hook = PreciseBNHook(num_samples=12)
|
||||
assert precise_bn_hook.num_samples == 12
|
||||
assert precise_bn_hook.interval == 1
|
||||
model = SingleBNModel()
|
||||
runner = EpochBasedRunner(
|
||||
model=model,
|
||||
batch_processor=None,
|
||||
optimizer=optimizer,
|
||||
logger=logger,
|
||||
max_epochs=1)
|
||||
runner.register_hook(precise_bn_hook)
|
||||
runner.run(loaders, [('train', 1)])
|
||||
imgs_list = list()
|
||||
for loader in loaders:
|
||||
for i, data in enumerate(loader):
|
||||
imgs_list.append(np.array(data['imgs']))
|
||||
mean = np.mean([np.mean(batch) for batch in imgs_list])
|
||||
# bassel correction used in Pytorch, therefore ddof=1
|
||||
var = np.mean([np.var(batch, ddof=1) for batch in imgs_list])
|
||||
assert np.equal(mean, np.array(
|
||||
model.bn.running_mean)), (mean, np.array(model.bn.running_mean))
|
||||
assert np.equal(var, np.array(
|
||||
model.bn.running_var)), (var, np.array(model.bn.running_var))
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='requires CUDA support')
|
||||
def test_ddp_model_precise_bn():
|
||||
# test DDP model
|
||||
test_bigger_dataset = BiggerDataset()
|
||||
loader = DataLoader(test_bigger_dataset, batch_size=2)
|
||||
loaders = [loader]
|
||||
precise_bn_hook = PreciseBNHook(num_samples=5)
|
||||
assert precise_bn_hook.num_samples == 5
|
||||
assert precise_bn_hook.interval == 1
|
||||
model = ExampleModel()
|
||||
model = MMDistributedDataParallel(
|
||||
model.cuda(),
|
||||
device_ids=[torch.cuda.current_device()],
|
||||
broadcast_buffers=False,
|
||||
find_unused_parameters=True)
|
||||
runner = EpochBasedRunner(
|
||||
model=model,
|
||||
batch_processor=None,
|
||||
optimizer=optimizer,
|
||||
logger=logger,
|
||||
max_epochs=1)
|
||||
runner.register_hook(precise_bn_hook)
|
||||
runner.run(loaders, [('train', 1)])
|
|
@ -1,9 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
|
||||
from mmengine.logging import MMLogger
|
||||
|
||||
from mmcls.utils import get_root_logger, load_json_log
|
||||
|
||||
|
||||
|
|
|
@ -1,100 +0,0 @@
|
|||
# Copyright (c) Open-MMLab. All rights reserved.
|
||||
import os
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mmcls.core import visualization as vis
|
||||
|
||||
|
||||
def test_color():
|
||||
assert vis.color_val_matplotlib(mmcv.Color.blue) == (0., 0., 1.)
|
||||
assert vis.color_val_matplotlib('green') == (0., 1., 0.)
|
||||
assert vis.color_val_matplotlib((1, 2, 3)) == (3 / 255, 2 / 255, 1 / 255)
|
||||
assert vis.color_val_matplotlib(100) == (100 / 255, 100 / 255, 100 / 255)
|
||||
assert vis.color_val_matplotlib(np.zeros(3, dtype=int)) == (0., 0., 0.)
|
||||
# forbid white color
|
||||
with pytest.raises(TypeError):
|
||||
vis.color_val_matplotlib([255, 255, 255])
|
||||
# forbid float
|
||||
with pytest.raises(TypeError):
|
||||
vis.color_val_matplotlib(1.0)
|
||||
# overflowed
|
||||
with pytest.raises(AssertionError):
|
||||
vis.color_val_matplotlib((0, 0, 500))
|
||||
|
||||
|
||||
def test_imshow_infos():
|
||||
tmp_dir = osp.join(tempfile.gettempdir(), 'image_infos')
|
||||
tmp_filename = osp.join(tmp_dir, 'image.jpg')
|
||||
|
||||
image = np.ones((10, 10, 3), np.uint8)
|
||||
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
|
||||
out_image = vis.imshow_infos(
|
||||
image, result, out_file=tmp_filename, show=False)
|
||||
assert osp.isfile(tmp_filename)
|
||||
assert image.shape == out_image.shape
|
||||
assert not np.allclose(image, out_image)
|
||||
os.remove(tmp_filename)
|
||||
|
||||
# test grayscale images
|
||||
image = np.ones((10, 10), np.uint8)
|
||||
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
|
||||
out_image = vis.imshow_infos(
|
||||
image, result, out_file=tmp_filename, show=False)
|
||||
assert osp.isfile(tmp_filename)
|
||||
assert image.shape == out_image.shape[:2]
|
||||
os.remove(tmp_filename)
|
||||
|
||||
|
||||
def test_figure_context_manager():
|
||||
# test show multiple images with the same figure.
|
||||
images = [
|
||||
np.random.randint(0, 255, (100, 100, 3), np.uint8) for _ in range(5)
|
||||
]
|
||||
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
|
||||
|
||||
with vis.ImshowInfosContextManager() as manager:
|
||||
fig_show = manager.fig_show
|
||||
fig_save = manager.fig_save
|
||||
|
||||
# Test time out
|
||||
fig_show.canvas.start_event_loop = MagicMock()
|
||||
fig_show.canvas.end_event_loop = MagicMock()
|
||||
for image in images:
|
||||
ret, out_image = manager.put_img_infos(image, result, show=True)
|
||||
assert ret == 0
|
||||
assert image.shape == out_image.shape
|
||||
assert not np.allclose(image, out_image)
|
||||
assert fig_show is manager.fig_show
|
||||
assert fig_save is manager.fig_save
|
||||
|
||||
# Test continue key
|
||||
fig_show.canvas.start_event_loop = (
|
||||
lambda _: fig_show.canvas.key_press_event(' '))
|
||||
for image in images:
|
||||
ret, out_image = manager.put_img_infos(image, result, show=True)
|
||||
assert ret == 0
|
||||
assert image.shape == out_image.shape
|
||||
assert not np.allclose(image, out_image)
|
||||
assert fig_show is manager.fig_show
|
||||
assert fig_save is manager.fig_save
|
||||
|
||||
# Test close figure manually
|
||||
fig_show = manager.fig_show
|
||||
|
||||
def destroy(*_, **__):
|
||||
fig_show.canvas.close_event()
|
||||
plt.close(fig_show)
|
||||
|
||||
fig_show.canvas.start_event_loop = destroy
|
||||
ret, out_image = manager.put_img_infos(images[0], result, show=True)
|
||||
assert ret == 1
|
||||
assert image.shape == out_image.shape
|
||||
assert not np.allclose(image, out_image)
|
||||
assert fig_save is manager.fig_save
|
|
@ -2,12 +2,12 @@
|
|||
import argparse
|
||||
import copy
|
||||
import math
|
||||
import pkg_resources
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pkg_resources
|
||||
from mmcv import Config, DictAction
|
||||
from mmcv.utils import to_2tuple
|
||||
from torch.nn import BatchNorm1d, BatchNorm2d, GroupNorm, LayerNorm
|
||||
|
|
Loading…
Reference in New Issue