[Feature] Add `ClsVisualizer`.

pull/913/head
mzr1996 2022-05-18 16:13:30 +00:00
parent 27e685fe10
commit 0537c4d70c
17 changed files with 262 additions and 1458 deletions

View File

@@ -4,3 +4,4 @@ from .evaluation import * # noqa: F401, F403
from .hook import * # noqa: F401, F403
from .optimizers import * # noqa: F401, F403
from .utils import * # noqa: F401, F403
from .visualization import * # noqa: F401, F403

View File

@@ -10,8 +10,8 @@ import mmcv
import torch
import torch.nn as nn
from mmcv.runner import EpochBasedRunner, get_dist_info
from mmengine.logging import print_log
from mmengine.hooks import Hook
from mmengine.logging import print_log
from torch.functional import Tensor
from torch.nn import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm

View File

@@ -1,8 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .image import (BaseFigureContextManager, ImshowInfosContextManager,
color_val_matplotlib, imshow_infos)
from .cls_visualizer import ClsVisualizer
__all__ = [
'BaseFigureContextManager', 'ImshowInfosContextManager', 'imshow_infos',
'color_val_matplotlib'
]
__all__ = ['ClsVisualizer']
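After this change, ``ClsVisualizer`` is the only export of ``mmcls.core.visualization``; a quick sketch of the new import surface (the old helpers such as ``imshow_infos`` are removed further below):

from mmcls.core.visualization import ClsVisualizer  # new export
# BaseFigureContextManager, ImshowInfosContextManager, imshow_infos and
# color_val_matplotlib are no longer available from this module.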

View File

@@ -0,0 +1,153 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import mmcv
import numpy as np
from mmengine import Visualizer
from mmengine.dist import master_only
from mmcls.core import ClsDataSample
from mmcls.registry import VISUALIZERS
@VISUALIZERS.register_module()
class ClsVisualizer(Visualizer):
"""Universal Visualizer for classification task.
Args:
name (str): Name of the instance. Defaults to 'visualizer'.
image (np.ndarray, optional): The original image to draw. The format
should be RGB. Defaults to None.
vis_backends (list, optional): List of visual backend configs.
Defaults to None.
save_dir (str, optional): Directory to save files for all storage
backends. If it is None, the backend storage will not save any data.
fig_save_cfg (dict): Keyword arguments of the figure for saving.
Defaults to an empty dict.
fig_show_cfg (dict): Keyword arguments of the figure for showing.
Defaults to an empty dict.
Examples:
>>> import torch
>>> import mmcv
>>> from pathlib import Path
>>> from mmcls.core import ClsDataSample, ClsVisualizer
>>> # Example image
>>> img = mmcv.imread("./demo/bird.JPEG", channel_order='rgb')
>>> # Example annotation
>>> data_sample = ClsDataSample().set_gt_label(1).set_pred_label(1).\
... set_pred_score(torch.tensor([0.1, 0.8, 0.1]))
>>> # Setup the visualizer
>>> vis = ClsVisualizer(
... save_dir="./outputs",
... vis_backends=[dict(type='LocalVisBackend')])
>>> # Set classes names
>>> vis.dataset_meta = {'CLASSES': ['cat', 'bird', 'dog']}
>>> # Show the example image with its annotation in a figure.
>>> # This ignores all preset storage backends.
>>> vis.add_datasample('res', img, data_sample, show=True)
>>> # Save the visualization result by the specified storage backends.
>>> vis.add_datasample('res', img, data_sample)
>>> assert Path('./outputs/vis_data/vis_image/res_0.png').exists()
>>> # Save another visualization result with the same name.
>>> vis.add_datasample('res', img, data_sample, step=1)
>>> assert Path('./outputs/vis_data/vis_image/res_1.png').exists()
"""
@master_only
def add_datasample(self,
name: str,
image: np.ndarray,
data_sample: Optional[ClsDataSample] = None,
draw_gt: bool = True,
draw_pred: bool = True,
draw_score: bool = True,
show: bool = False,
text_cfg: dict = dict(),
wait_time: float = 0,
out_file: Optional[str] = None,
step: int = 0) -> None:
"""Draw datasample and save to all backends.
- If ``show`` is True, all storage backends are ignored and the drawn
image is displayed in a local window instead.
- If ``out_file`` is specified, the drawn image will additionally be
saved to ``out_file``. This is usually used in script mode, such as
``image_demo.py``.
Args:
name (str): The image identifier.
image (np.ndarray): The image to draw.
data_sample (:obj:`ClsDataSample`, optional): The annotation of the
image. Defaults to None.
draw_gt (bool): Whether to draw ground truth labels.
Defaults to True.
draw_pred (bool): Whether to draw prediction labels.
Defaults to True.
draw_score (bool): Whether to draw the scores of the predicted
categories. Defaults to True.
show (bool): Whether to display the drawn image. Defaults to False.
text_cfg (dict): Extra text settings, which accept the arguments of
:attr:`mmengine.Visualizer.draw_texts`. Defaults to an empty dict.
wait_time (float): Display time of the window, in seconds.
Defaults to 0, which means "forever".
out_file (str, optional): Extra path to save the visualization
result. Even if specified, the visualizer will still save the
result through its storage backends. Defaults to None.
step (int): Global step value to record. Defaults to 0.
"""
classes = None
if self.dataset_meta is not None:
classes = self.dataset_meta.get('CLASSES', None)
texts = []
self.set_image(image)
if data_sample is not None and draw_gt and 'gt_label' in data_sample:
gt_label = data_sample.gt_label
idx = gt_label.label.tolist()
class_labels = [''] * len(idx)
if classes is not None:
class_labels = [f' ({classes[i]})' for i in idx]
labels = [str(idx[i]) + class_labels[i] for i in range(len(idx))]
prefix = 'Ground truth: '
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
if data_sample is not None and draw_pred and 'pred_label' in data_sample:
pred_label = data_sample.pred_label
idx = pred_label.label.tolist()
score_labels = [''] * len(idx)
class_labels = [''] * len(idx)
if draw_score and 'score' in pred_label:
score_labels = [
f', {pred_label.score[i].item():.2f}' for i in idx
]
if classes is not None:
class_labels = [f' ({classes[i]})' for i in idx]
labels = [
str(idx[i]) + score_labels[i] + class_labels[i]
for i in range(len(idx))
]
prefix = 'Prediction: '
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
text_cfg = {
'positions': np.array([(5, 5)]),
'font_families': 'monospace',
'colors': 'white',
'bboxes': dict(facecolor='black', alpha=0.5, boxstyle='Round'),
**text_cfg
}
self.draw_texts('\n'.join(texts), **text_cfg)
drawn_img = self.get_image()
if show:
self.show(drawn_img, win_name=name, wait_time=wait_time)
else:
self.add_image(name, drawn_img, step=step)
if out_file is not None:
mmcv.imwrite(drawn_img[..., ::-1], out_file)
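A small usage sketch (not part of this commit) showing how ``out_file`` adds an extra save path and how keys passed in ``text_cfg`` override the defaults merged above; the image path and class names are illustrative:

import mmcv
import torch

from mmcls.core import ClsDataSample, ClsVisualizer

# Illustrative inputs: any RGB image and a 3-class prediction.
img = mmcv.imread('./demo/bird.JPEG', channel_order='rgb')
data_sample = ClsDataSample().set_gt_label(1).set_pred_label(1).\
    set_pred_score(torch.tensor([0.1, 0.8, 0.1]))

vis = ClsVisualizer(
    save_dir='./outputs', vis_backends=[dict(type='LocalVisBackend')])
vis.dataset_meta = {'CLASSES': ['cat', 'bird', 'dog']}

# 'colors' replaces the default white text; the other defaults
# (positions, font_families, bboxes) are kept.
vis.add_datasample(
    'demo', img, data_sample,
    text_cfg=dict(colors='yellow'),
    out_file='./outputs/demo_vis.png')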

View File

@@ -1,343 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import matplotlib.pyplot as plt
import mmcv
import numpy as np
from matplotlib.backend_bases import CloseEvent
# A small value
EPS = 1e-2
def color_val_matplotlib(color):
"""Convert various input in BGR order to normalized RGB matplotlib color
tuples,
Args:
color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Color inputs
Returns:
tuple[float]: A tuple of 3 normalized floats indicating RGB channels.
"""
color = mmcv.color_val(color)
color = [color / 255 for color in color[::-1]]
return tuple(color)
class BaseFigureContextManager:
"""Context Manager to reuse matplotlib figure.
It provides a figure for saving and a figure for showing to support
different settings.
Args:
axis (bool): Whether to show the axis lines.
fig_save_cfg (dict): Keyword parameters of figure for saving.
Defaults to empty dict.
fig_show_cfg (dict): Keyword parameters of figure for showing.
Defaults to empty dict.
"""
def __init__(self, axis=False, fig_save_cfg={}, fig_show_cfg={}) -> None:
self.is_inline = 'inline' in plt.get_backend()
# Because save and show need different figure size
# We set two figure and axes to handle save and show
self.fig_save: plt.Figure = None
self.fig_save_cfg = fig_save_cfg
self.ax_save: plt.Axes = None
self.fig_show: plt.Figure = None
self.fig_show_cfg = fig_show_cfg
self.ax_show: plt.Axes = None
self.axis = axis
def __enter__(self):
if not self.is_inline:
# If use inline backend, we cannot control which figure to show,
# so disable the interactive fig_show, and put the initialization
# of fig_save to `prepare` function.
self._initialize_fig_save()
self._initialize_fig_show()
return self
def _initialize_fig_save(self):
fig = plt.figure(**self.fig_save_cfg)
ax = fig.add_subplot()
# remove white edges by setting subplot margins
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
self.fig_save, self.ax_save = fig, ax
def _initialize_fig_show(self):
# fig_save will be resized to image size, only fig_show needs fig_size.
fig = plt.figure(**self.fig_show_cfg)
ax = fig.add_subplot()
# remove white edges by setting subplot margins
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
self.fig_show, self.ax_show = fig, ax
def __exit__(self, exc_type, exc_value, traceback):
if self.is_inline:
# If use inline backend, whether to close figure depends on if
# users want to show the image.
return
plt.close(self.fig_save)
plt.close(self.fig_show)
def prepare(self):
if self.is_inline:
# if use inline backend, just rebuild the fig_save.
self._initialize_fig_save()
self.ax_save.cla()
self.ax_save.axis(self.axis)
return
# If users force to destroy the window, rebuild fig_show.
if not plt.fignum_exists(self.fig_show.number):
self._initialize_fig_show()
# Clear all axes
self.ax_save.cla()
self.ax_save.axis(self.axis)
self.ax_show.cla()
self.ax_show.axis(self.axis)
def wait_continue(self, timeout=0, continue_key=' ') -> int:
"""Show the image and wait for the user's input.
This implementation refers to
https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py
Args:
timeout (int): If positive, continue after ``timeout`` seconds.
Defaults to 0.
continue_key (str): The key for users to continue. Defaults to
the space key.
Returns:
int: If zero, means time out or the user pressed ``continue_key``,
and if one, means the user closed the show figure.
""" # noqa: E501
if self.is_inline:
# If use inline backend, interactive input and timeout is no use.
return
if self.fig_show.canvas.manager:
# Ensure that the figure is shown
self.fig_show.show()
while True:
# Connect the events to the handler function call.
event = None
def handler(ev):
# Set external event variable
nonlocal event
# Qt backend may fire two events at the same time,
# use a condition to avoid missing close event.
event = ev if not isinstance(event, CloseEvent) else event
self.fig_show.canvas.stop_event_loop()
cids = [
self.fig_show.canvas.mpl_connect(name, handler)
for name in ('key_press_event', 'close_event')
]
try:
self.fig_show.canvas.start_event_loop(timeout)
finally: # Run even on exception like ctrl-c.
# Disconnect the callbacks.
for cid in cids:
self.fig_show.canvas.mpl_disconnect(cid)
if isinstance(event, CloseEvent):
return 1 # Quit for close.
elif event is None or event.key == continue_key:
return 0 # Quit for continue.
class ImshowInfosContextManager(BaseFigureContextManager):
"""Context Manager to reuse matplotlib figure and put infos on images.
Args:
fig_size (tuple[int]): Size of the figure to show image.
Examples:
>>> import mmcv
>>> from mmcls.core import visualization as vis
>>> img1 = mmcv.imread("./1.png")
>>> info1 = {'class': 'cat', 'label': 0}
>>> img2 = mmcv.imread("./2.png")
>>> info2 = {'class': 'dog', 'label': 1}
>>> with vis.ImshowInfosContextManager() as manager:
... # Show img1
... manager.put_img_infos(img1, info1)
... # Show img2 on the same figure and save output image.
... manager.put_img_infos(
... img2, info2, out_file='./2_out.png')
"""
def __init__(self, fig_size=(15, 10)):
super().__init__(
axis=False,
# A proper dpi for image save with default font size.
fig_save_cfg=dict(frameon=False, dpi=36),
fig_show_cfg=dict(frameon=False, figsize=fig_size))
def _put_text(self, ax, text, x, y, text_color, font_size):
ax.text(
x,
y,
f'{text}',
bbox={
'facecolor': 'black',
'alpha': 0.7,
'pad': 0.2,
'edgecolor': 'none',
'boxstyle': 'round'
},
color=text_color,
fontsize=font_size,
family='monospace',
verticalalignment='top',
horizontalalignment='left')
def put_img_infos(self,
img,
infos,
text_color='white',
font_size=26,
row_width=20,
win_name='',
show=True,
wait_time=0,
out_file=None):
"""Show image with extra information.
Args:
img (str | ndarray): The image to be displayed.
infos (dict): Extra infos to display in the image.
text_color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Extra infos
display color. Defaults to 'white'.
font_size (int): Extra infos display font size. Defaults to 26.
row_width (int): width between each row of results on the image.
win_name (str): The image title. Defaults to ''
show (bool): Whether to show the image. Defaults to True.
wait_time (int): How many seconds to display the image.
Defaults to 0.
out_file (Optional[str]): The filename to write the image.
Defaults to None.
Returns:
np.ndarray: The image with extra information.
"""
self.prepare()
text_color = color_val_matplotlib(text_color)
img = mmcv.imread(img).astype(np.uint8)
x, y = 3, row_width // 2
img = mmcv.bgr2rgb(img)
width, height = img.shape[1], img.shape[0]
img = np.ascontiguousarray(img)
# add a small EPS to avoid precision lost due to matplotlib's
# truncation (https://github.com/matplotlib/matplotlib/issues/15363)
dpi = self.fig_save.get_dpi()
self.fig_save.set_size_inches((width + EPS) / dpi,
(height + EPS) / dpi)
for k, v in infos.items():
if isinstance(v, float):
v = f'{v:.2f}'
label_text = f'{k}: {v}'
self._put_text(self.ax_save, label_text, x, y, text_color,
font_size)
if show and not self.is_inline:
self._put_text(self.ax_show, label_text, x, y, text_color,
font_size)
y += row_width
self.ax_save.imshow(img)
stream, _ = self.fig_save.canvas.print_to_buffer()
buffer = np.frombuffer(stream, dtype='uint8')
img_rgba = buffer.reshape(height, width, 4)
rgb, _ = np.split(img_rgba, [3], axis=2)
img_save = rgb.astype('uint8')
img_save = mmcv.rgb2bgr(img_save)
if out_file is not None:
mmcv.imwrite(img_save, out_file)
ret = 0
if show and not self.is_inline:
# Reserve some space for the tip.
self.ax_show.set_title(win_name)
self.ax_show.set_ylim(height + 20)
self.ax_show.text(
width // 2,
height + 18,
'Press SPACE to continue.',
ha='center',
fontsize=font_size)
self.ax_show.imshow(img)
# Refresh canvas, necessary for Qt5 backend.
self.fig_show.canvas.draw()
ret = self.wait_continue(timeout=wait_time)
elif (not show) and self.is_inline:
# If use inline backend, we use fig_save to show the image
# So we need to close it if users don't want to show.
plt.close(self.fig_save)
return ret, img_save
def imshow_infos(img,
infos,
text_color='white',
font_size=26,
row_width=20,
win_name='',
show=True,
fig_size=(15, 10),
wait_time=0,
out_file=None):
"""Show image with extra information.
Args:
img (str | ndarray): The image to be displayed.
infos (dict): Extra infos to display in the image.
text_color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Extra infos
display color. Defaults to 'white'.
font_size (int): Extra infos display font size. Defaults to 26.
row_width (int): width between each row of results on the image.
win_name (str): The image title. Defaults to ''
show (bool): Whether to show the image. Defaults to True.
fig_size (tuple): Image show figure size. Defaults to (15, 10).
wait_time (int): How many seconds to display the image. Defaults to 0.
out_file (Optional[str]): The filename to write the image.
Defaults to None.
Returns:
np.ndarray: The image with extra information.
"""
with ImshowInfosContextManager(fig_size=fig_size) as manager:
_, img = manager.put_img_infos(
img,
infos,
text_color=text_color,
font_size=font_size,
row_width=row_width,
win_name=win_name,
show=show,
wait_time=wait_time,
out_file=out_file)
return img

View File

@@ -3,13 +3,10 @@ from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import Sequence
import mmcv
import torch
import torch.distributed as dist
from mmcv.runner import BaseModule, auto_fp16
from mmcls.core.visualization import imshow_infos
class BaseClassifier(BaseModule, metaclass=ABCMeta):
"""Base class for classifiers."""
@@ -174,51 +171,3 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
return outputs
def show_result(self,
img,
result,
text_color='white',
font_scale=0.5,
row_width=20,
show=False,
fig_size=(15, 10),
win_name='',
wait_time=0,
out_file=None):
"""Draw `result` over `img`.
Args:
img (str or ndarray): The image to be displayed.
result (dict): The classification results to draw over `img`.
text_color (str or tuple or :obj:`Color`): Color of texts.
font_scale (float): Font scales of texts.
row_width (int): width between each row of results on the image.
show (bool): Whether to show the image.
Default: False.
fig_size (tuple): Image show figure size. Defaults to (15, 10).
win_name (str): The window name.
wait_time (int): How many seconds to display the image.
Defaults to 0.
out_file (str or None): The filename to write the image.
Default: None.
Returns:
img (ndarray): Image with overlaid results.
"""
img = mmcv.imread(img)
img = img.copy()
img = imshow_infos(
img,
result,
text_color=text_color,
font_size=int(font_scale * 50),
row_width=row_width,
win_name=win_name,
show=show,
fig_size=fig_size,
wait_time=wait_time,
out_file=out_file)
return img
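``show_result`` is removed without a direct replacement on the classifier, so here is a hedged sketch of the equivalent workflow with the new visualizer (the scores and class names below are illustrative, not taken from this diff):

import mmcv
import torch

from mmcls.core import ClsDataSample, ClsVisualizer

img = mmcv.imread('./demo/bird.JPEG', channel_order='rgb')  # illustrative path
data_sample = ClsDataSample().set_pred_label(0).set_pred_score(
    torch.tensor([0.9, 0.1]))

vis = ClsVisualizer()
vis.dataset_meta = {'CLASSES': ['cat', 'dog']}
# show=True displays the result in a local window and skips the storage
# backends, similar to the old ``show_result(img, result, show=True)``.
vis.add_datasample('result', img, data_sample, show=True)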

View File

@@ -45,21 +45,21 @@ DATASETS = Registry('dataset', parent=MMENGINE_DATASETS)
DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS)
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)
# mangage all kinds of modules inheriting `nn.Module`
# manage all kinds of modules inheriting `nn.Module`
MODELS = Registry('model', parent=MMENGINE_MODELS)
# mangage all kinds of model wrappers like 'MMDistributedDataParallel'
# manage all kinds of model wrappers like 'MMDistributedDataParallel'
MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS)
# mangage all kinds of weight initialization modules like `Uniform`
# manage all kinds of weight initialization modules like `Uniform`
WEIGHT_INITIALIZERS = Registry(
'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS)
# Registries For Optimizer and the related
# mangage all kinds of optimizers like `SGD` and `Adam`
# manage all kinds of optimizers like `SGD` and `Adam`
OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
# manage constructors that customize the optimization hyperparameters.
OPTIMIZER_CONSTRUCTORS = Registry(
'optimizer constructor', parent=MMENGINE_OPTIMIZER_CONSTRUCTORS)
# mangage all kinds of parameter schedulers like `MultiStepLR`
# manage all kinds of parameter schedulers like `MultiStepLR`
PARAM_SCHEDULERS = Registry(
'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)
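Because the new class is registered with ``@VISUALIZERS.register_module()``, it can also be built from a config dict through the registry; a brief sketch, assuming the standard mmengine ``Registry.build`` interface:

from mmcls.registry import VISUALIZERS

vis_cfg = dict(
    type='ClsVisualizer',
    vis_backends=[dict(type='LocalVisBackend')],
    save_dir='./outputs')
visualizer = VISUALIZERS.build(vis_cfg)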

View File

@@ -0,0 +1,98 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
from unittest.mock import patch
import numpy as np
import torch
from mmcls.core import ClsDataSample, ClsVisualizer
class TestClsVisualizer(TestCase):
def setUp(self) -> None:
super().setUp()
tmpdir = tempfile.TemporaryDirectory()
self.tmpdir = tmpdir
self.vis = ClsVisualizer(
save_dir=tmpdir.name,
vis_backends=[dict(type='LocalVisBackend')],
)
def test_add_datasample(self):
image = np.ones((10, 10, 3), np.uint8)
data_sample = ClsDataSample().set_gt_label(1).set_pred_label(1).\
set_pred_score(torch.tensor([0.1, 0.8, 0.1]))
# Test show
def mock_show(drawn_img, win_name, wait_time):
self.assertFalse((image == drawn_img).all())
self.assertEqual(win_name, 'test')
self.assertEqual(wait_time, 0)
with patch.object(self.vis, 'show', mock_show):
self.vis.add_datasample(
'test', image=image, data_sample=data_sample, show=True)
# Test out_file
out_file = osp.join(self.tmpdir.name, 'results.png')
self.vis.add_datasample(
'test', image=image, data_sample=data_sample, out_file=out_file)
self.assertTrue(osp.exists(out_file))
# Test storage backend.
save_file = osp.join(self.tmpdir.name, 'vis_data/vis_image/test_0.png')
self.assertTrue(osp.exists(save_file))
# Test with dataset_meta
self.vis.dataset_meta = {'CLASSES': ['cat', 'bird', 'dog']}
def test_texts(text, *_, **__):
self.assertEqual(
text, '\n'.join([
'Ground truth: 1 (bird)',
'Prediction: 1, 0.80 (bird)',
]))
with patch.object(self.vis, 'draw_texts', test_texts):
self.vis.add_datasample(
'test', image=image, data_sample=data_sample)
# Test without pred_label
def test_texts(text, *_, **__):
self.assertEqual(text, '\n'.join([
'Ground truth: 1 (bird)',
]))
with patch.object(self.vis, 'draw_texts', test_texts):
self.vis.add_datasample(
'test', image=image, data_sample=data_sample, draw_pred=False)
# Test without gt_label
def test_texts(text, *_, **__):
self.assertEqual(text, '\n'.join([
'Prediction: 1, 0.80 (bird)',
]))
with patch.object(self.vis, 'draw_texts', test_texts):
self.vis.add_datasample(
'test', image=image, data_sample=data_sample, draw_gt=False)
# Test without score
del data_sample.pred_label.score
def test_texts(text, *_, **__):
self.assertEqual(
text, '\n'.join([
'Ground truth: 1 (bird)',
'Prediction: 1 (bird)',
]))
with patch.object(self.vis, 'draw_texts', test_texts):
self.vis.add_datasample(
'test', image=image, data_sample=data_sample)
def tearDown(self):
self.tmpdir.cleanup()

View File

@@ -1,118 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmcv import Config
from mmdet.apis import inference_detector
from mmdet.models import build_detector
from mmcls.models import (MobileNetV2, MobileNetV3, RegNet, ResNeSt, ResNet,
ResNeXt, SEResNet, SEResNeXt, SwinTransformer,
TIMMBackbone)
from mmcls.models.backbones.timm_backbone import timm
backbone_configs = dict(
mobilenetv2=dict(
backbone=dict(
type='mmcls.MobileNetV2',
widen_factor=1.0,
norm_cfg=dict(type='GN', num_groups=2, requires_grad=True),
out_indices=(4, 7)),
out_channels=[96, 1280]),
mobilenetv3=dict(
backbone=dict(
type='mmcls.MobileNetV3',
norm_cfg=dict(type='GN', num_groups=2, requires_grad=True),
out_indices=range(7, 12)),
out_channels=[48, 48, 96, 96, 96]),
regnet=dict(
backbone=dict(type='mmcls.RegNet', arch='regnetx_400mf'),
out_channels=384),
resnext=dict(
backbone=dict(
type='mmcls.ResNeXt', depth=50, groups=32, width_per_group=4),
out_channels=2048),
resnet=dict(
backbone=dict(type='mmcls.ResNet', depth=50), out_channels=2048),
seresnet=dict(
backbone=dict(type='mmcls.SEResNet', depth=50), out_channels=2048),
seresnext=dict(
backbone=dict(
type='mmcls.SEResNeXt', depth=50, groups=32, width_per_group=4),
out_channels=2048),
resnest=dict(
backbone=dict(
type='mmcls.ResNeSt',
depth=50,
radix=2,
reduction_factor=4,
out_indices=(0, 1, 2, 3)),
out_channels=[256, 512, 1024, 2048]),
swin=dict(
backbone=dict(
type='mmcls.SwinTransformer',
arch='small',
drop_path_rate=0.2,
img_size=800,
out_indices=(2, 3)),
out_channels=[384, 768]),
timm_efficientnet=dict(
backbone=dict(
type='mmcls.TIMMBackbone',
model_name='efficientnet_b1',
features_only=True,
pretrained=False,
out_indices=(1, 2, 3, 4)),
out_channels=[24, 40, 112, 320]),
timm_resnet=dict(
backbone=dict(
type='mmcls.TIMMBackbone',
model_name='resnet50',
features_only=True,
pretrained=False,
out_indices=(1, 2, 3, 4)),
out_channels=[256, 512, 1024, 2048]))
module_mapping = {
'mobilenetv2': MobileNetV2,
'mobilenetv3': MobileNetV3,
'regnet': RegNet,
'resnext': ResNeXt,
'resnet': ResNet,
'seresnext': SEResNeXt,
'seresnet': SEResNet,
'resnest': ResNeSt,
'swin': SwinTransformer,
'timm_efficientnet': TIMMBackbone,
'timm_resnet': TIMMBackbone
}
def test_mmdet_inference():
config_path = './tests/data/retinanet.py'
rng = np.random.RandomState(0)
img1 = rng.rand(100, 100, 3)
for module_name, backbone_config in backbone_configs.items():
module = module_mapping[module_name]
if module is TIMMBackbone and timm is None:
print(f'skip {module_name} because timm is not available')
continue
print(f'test {module_name}')
config = Config.fromfile(config_path)
config.model.backbone = backbone_config['backbone']
out_channels = backbone_config['out_channels']
if isinstance(out_channels, int):
config.model.neck = None
config.model.bbox_head.in_channels = out_channels
anchor_generator = config.model.bbox_head.anchor_generator
anchor_generator.strides = anchor_generator.strides[:1]
else:
config.model.neck.in_channels = out_channels
model = build_detector(config.model)
assert isinstance(model.backbone, module)
model.cfg = config
model.eval()
result = inference_detector(model, img1)
assert len(result) == config.num_classes

View File

@@ -1,9 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from copy import deepcopy
import numpy as np
import torch
from mmcv import ConfigDict
@@ -90,20 +87,6 @@ def test_image_classifier():
model = CLASSIFIERS.build(model_cfg_)
assert model.init_cfg == dict(type='Pretrained', checkpoint='checkpoint')
# test show_result
img = np.random.randint(0, 256, (224, 224, 3)).astype(np.uint8)
result = dict(pred_class='cat', pred_label=0, pred_score=0.9)
with tempfile.TemporaryDirectory() as tmpdir:
out_file = osp.join(tmpdir, 'out.png')
model.show_result(img, result, out_file=out_file)
assert osp.exists(out_file)
with tempfile.TemporaryDirectory() as tmpdir:
out_file = osp.join(tmpdir, 'out.png')
model.show_result(img, result, out_file=out_file)
assert osp.exists(out_file)
def test_image_classifier_with_mixup():
# Test mixup in ImageClassifier

View File

@@ -1,158 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import shutil
import tempfile
import numpy as np
import pytest
import torch
import torch.nn as nn
from mmcv.runner import build_runner
from mmengine.hooks import Hook, IterTimerHook
from torch.utils.data import DataLoader
import mmcls.core # noqa: F401
def _build_demo_runner_without_hook(runner_type='EpochBasedRunner',
max_epochs=1,
max_iters=None,
multi_optimziers=False):
class Model(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(2, 1)
self.conv = nn.Conv2d(3, 3, 3)
def forward(self, x):
return self.linear(x)
def train_step(self, x, optimizer, **kwargs):
return dict(loss=self(x))
def val_step(self, x, optimizer, **kwargs):
return dict(loss=self(x))
model = Model()
if multi_optimziers:
optimizer = {
'model1':
torch.optim.SGD(model.linear.parameters(), lr=0.02, momentum=0.95),
'model2':
torch.optim.SGD(model.conv.parameters(), lr=0.01, momentum=0.9),
}
else:
optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)
tmp_dir = tempfile.mkdtemp()
runner = build_runner(
dict(type=runner_type),
default_args=dict(
model=model,
work_dir=tmp_dir,
optimizer=optimizer,
logger=logging.getLogger(),
max_epochs=max_epochs,
max_iters=max_iters))
return runner
def _build_demo_runner(runner_type='EpochBasedRunner',
max_epochs=1,
max_iters=None,
multi_optimziers=False):
log_config = dict(
interval=1, hooks=[
dict(type='TextLoggerHook'),
])
runner = _build_demo_runner_without_hook(runner_type, max_epochs,
max_iters, multi_optimziers)
runner.register_checkpoint_hook(dict(interval=1))
runner.register_logger_hooks(log_config)
return runner
class ValueCheckHook(Hook):
def __init__(self, check_dict, by_epoch=False):
super().__init__()
self.check_dict = check_dict
self.by_epoch = by_epoch
def after_iter(self, runner):
if self.by_epoch:
return
if runner.iter in self.check_dict:
for attr, target in self.check_dict[runner.iter].items():
value = eval(f'runner.{attr}')
assert np.isclose(value, target), \
(f'The value of `runner.{attr}` is {value}, '
f'not equals to {target}')
def after_epoch(self, runner):
if not self.by_epoch:
return
if runner.epoch in self.check_dict:
for attr, target in self.check_dict[runner.epoch]:
value = eval(f'runner.{attr}')
assert np.isclose(value, target), \
(f'The value of `runner.{attr}` is {value}, '
f'not equals to {target}')
@pytest.mark.parametrize('multi_optimziers', (True, False))
def test_cosine_cooldown_hook(multi_optimziers):
"""xdoctest -m tests/test_hooks.py test_cosine_runner_hook."""
loader = DataLoader(torch.ones((10, 2)))
runner = _build_demo_runner(multi_optimziers=multi_optimziers)
# add momentum LR scheduler
hook_cfg = dict(
type='CosineAnnealingCooldownLrUpdaterHook',
by_epoch=False,
cool_down_time=2,
cool_down_ratio=0.1,
min_lr_ratio=0.1,
warmup_iters=2,
warmup_ratio=0.9)
runner.register_hook_from_cfg(hook_cfg)
runner.register_hook_from_cfg(dict(type='IterTimerHook'))
runner.register_hook(IterTimerHook())
if multi_optimziers:
check_hook = ValueCheckHook({
0: {
'current_lr()["model1"][0]': 0.02,
'current_lr()["model2"][0]': 0.01,
},
5: {
'current_lr()["model1"][0]': 0.0075558491,
'current_lr()["model2"][0]': 0.0037779246,
},
9: {
'current_lr()["model1"][0]': 0.0002,
'current_lr()["model2"][0]': 0.0001,
}
})
else:
check_hook = ValueCheckHook({
0: {
'current_lr()[0]': 0.02,
},
5: {
'current_lr()[0]': 0.0075558491,
},
9: {
'current_lr()[0]': 0.0002,
}
})
runner.register_hook(check_hook, priority='LOWEST')
runner.run([loader], [('train', 1)])
shutil.rmtree(runner.work_dir)

View File

@@ -1,84 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import tempfile
from unittest.mock import MagicMock
import mmcv.runner as mmcv_runner
import pytest
import torch
from mmcv.runner import obj_from_dict
from torch.utils.data import DataLoader, Dataset
from mmcls.core.hook import ClassNumCheckHook
from mmcls.models.heads.base_head import BaseHead
class ExampleDataset(Dataset):
def __init__(self, CLASSES):
self.CLASSES = CLASSES
def __getitem__(self, idx):
results = dict(img=torch.tensor([1]), img_metas=dict())
return results
def __len__(self):
return 1
class ExampleHead(BaseHead):
def __init__(self, init_cfg=None):
super(BaseHead, self).__init__(init_cfg)
self.num_classes = 4
def forward_train(self, x, gt_label=None, **kwargs):
pass
class ExampleModel(torch.nn.Module):
def __init__(self):
super(ExampleModel, self).__init__()
self.test_cfg = None
self.conv = torch.nn.Conv2d(3, 3, 3)
self.head = ExampleHead()
def forward(self, img, img_metas, test_mode=False, **kwargs):
return img
def train_step(self, data_batch, optimizer):
loss = self.forward(**data_batch)
return dict(loss=loss)
@pytest.mark.parametrize('runner_type',
['EpochBasedRunner', 'IterBasedRunner'])
@pytest.mark.parametrize(
'CLASSES', [None, ('A', 'B', 'C', 'D', 'E'), ('A', 'B', 'C', 'D')])
def test_num_class_hook(runner_type, CLASSES):
test_dataset = ExampleDataset(CLASSES)
loader = DataLoader(test_dataset, batch_size=1)
model = ExampleModel()
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer = obj_from_dict(optim_cfg, torch.optim,
dict(params=model.parameters()))
with tempfile.TemporaryDirectory() as tmpdir:
num_class_hook = ClassNumCheckHook()
logger_mock = MagicMock(spec=logging.Logger)
runner = getattr(mmcv_runner, runner_type)(
model=model,
optimizer=optimizer,
work_dir=tmpdir,
logger=logger_mock,
max_epochs=1)
runner.register_hook(num_class_hook)
if CLASSES is None:
runner.run([loader], [('train', 1)], 1)
logger_mock.warning.assert_called()
elif len(CLASSES) != 4:
with pytest.raises(AssertionError):
runner.run([loader], [('train', 1)], 1)
else:
runner.run([loader], [('train', 1)], 1)

View File

@@ -1,309 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from collections import OrderedDict
from copy import deepcopy
from typing import Iterable
import torch
import torch.nn as nn
from mmcv.runner import build_optimizer
from mmcv.runner.optimizer.builder import OPTIMIZERS
from mmcv.utils.registry import build_from_cfg
from torch.autograd import Variable
from torch.optim.optimizer import Optimizer
import mmcls.core # noqa: F401
base_lr = 0.01
base_wd = 0.0001
def assert_equal(x, y):
if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
torch.testing.assert_allclose(x, y.to(x.device))
elif isinstance(x, OrderedDict) and isinstance(y, OrderedDict):
for x_value, y_value in zip(x.values(), y.values()):
assert_equal(x_value, y_value)
elif isinstance(x, dict) and isinstance(y, dict):
assert x.keys() == y.keys()
for key in x.keys():
assert_equal(x[key], y[key])
elif isinstance(x, str) and isinstance(y, str):
assert x == y
elif isinstance(x, Iterable) and isinstance(y, Iterable):
assert len(x) == len(y)
for x_item, y_item in zip(x, y):
assert_equal(x_item, y_item)
else:
assert x == y
class SubModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(2, 2, kernel_size=1, groups=2)
self.gn = nn.GroupNorm(2, 2)
self.fc = nn.Linear(2, 2)
self.param1 = nn.Parameter(torch.ones(1))
def forward(self, x):
return x
class ExampleModel(nn.Module):
def __init__(self):
super().__init__()
self.param1 = nn.Parameter(torch.ones(1))
self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False)
self.conv2 = nn.Conv2d(4, 2, kernel_size=1)
self.bn = nn.BatchNorm2d(2)
self.sub = SubModel()
self.fc = nn.Linear(2, 1)
def forward(self, x):
return x
def check_lamb_optimizer(optimizer,
model,
bias_lr_mult=1,
bias_decay_mult=1,
norm_decay_mult=1,
dwconv_decay_mult=1):
param_groups = optimizer.param_groups
assert isinstance(optimizer, Optimizer)
assert optimizer.defaults['lr'] == base_lr
assert optimizer.defaults['weight_decay'] == base_wd
model_parameters = list(model.parameters())
assert len(param_groups) == len(model_parameters)
for i, param in enumerate(model_parameters):
param_group = param_groups[i]
assert torch.equal(param_group['params'][0], param)
# param1
param1 = param_groups[0]
assert param1['lr'] == base_lr
assert param1['weight_decay'] == base_wd
# conv1.weight
conv1_weight = param_groups[1]
assert conv1_weight['lr'] == base_lr
assert conv1_weight['weight_decay'] == base_wd
# conv2.weight
conv2_weight = param_groups[2]
assert conv2_weight['lr'] == base_lr
assert conv2_weight['weight_decay'] == base_wd
# conv2.bias
conv2_bias = param_groups[3]
assert conv2_bias['lr'] == base_lr * bias_lr_mult
assert conv2_bias['weight_decay'] == base_wd * bias_decay_mult
# bn.weight
bn_weight = param_groups[4]
assert bn_weight['lr'] == base_lr
assert bn_weight['weight_decay'] == base_wd * norm_decay_mult
# bn.bias
bn_bias = param_groups[5]
assert bn_bias['lr'] == base_lr
assert bn_bias['weight_decay'] == base_wd * norm_decay_mult
# sub.param1
sub_param1 = param_groups[6]
assert sub_param1['lr'] == base_lr
assert sub_param1['weight_decay'] == base_wd
# sub.conv1.weight
sub_conv1_weight = param_groups[7]
assert sub_conv1_weight['lr'] == base_lr
assert sub_conv1_weight['weight_decay'] == base_wd * dwconv_decay_mult
# sub.conv1.bias
sub_conv1_bias = param_groups[8]
assert sub_conv1_bias['lr'] == base_lr * bias_lr_mult
assert sub_conv1_bias['weight_decay'] == base_wd * dwconv_decay_mult
# sub.gn.weight
sub_gn_weight = param_groups[9]
assert sub_gn_weight['lr'] == base_lr
assert sub_gn_weight['weight_decay'] == base_wd * norm_decay_mult
# sub.gn.bias
sub_gn_bias = param_groups[10]
assert sub_gn_bias['lr'] == base_lr
assert sub_gn_bias['weight_decay'] == base_wd * norm_decay_mult
# sub.fc1.weight
sub_fc_weight = param_groups[11]
assert sub_fc_weight['lr'] == base_lr
assert sub_fc_weight['weight_decay'] == base_wd
# sub.fc1.bias
sub_fc_bias = param_groups[12]
assert sub_fc_bias['lr'] == base_lr * bias_lr_mult
assert sub_fc_bias['weight_decay'] == base_wd * bias_decay_mult
# fc1.weight
fc_weight = param_groups[13]
assert fc_weight['lr'] == base_lr
assert fc_weight['weight_decay'] == base_wd
# fc1.bias
fc_bias = param_groups[14]
assert fc_bias['lr'] == base_lr * bias_lr_mult
assert fc_bias['weight_decay'] == base_wd * bias_decay_mult
def _test_state_dict(weight, bias, input, constructor):
weight = Variable(weight, requires_grad=True)
bias = Variable(bias, requires_grad=True)
inputs = Variable(input)
def fn_base(optimizer, weight, bias):
optimizer.zero_grad()
i = input_cuda if weight.is_cuda else inputs
loss = (weight.mv(i) + bias).pow(2).sum()
loss.backward()
return loss
optimizer = constructor(weight, bias)
fn = functools.partial(fn_base, optimizer, weight, bias)
# Prime the optimizer
for _ in range(20):
optimizer.step(fn)
# Clone the weights and construct new optimizer for them
weight_c = Variable(weight.data.clone(), requires_grad=True)
bias_c = Variable(bias.data.clone(), requires_grad=True)
optimizer_c = constructor(weight_c, bias_c)
fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
# Load state dict
state_dict = deepcopy(optimizer.state_dict())
state_dict_c = deepcopy(optimizer.state_dict())
optimizer_c.load_state_dict(state_dict_c)
# Run both optimizations in parallel
for _ in range(20):
optimizer.step(fn)
optimizer_c.step(fn_c)
assert_equal(weight, weight_c)
assert_equal(bias, bias_c)
# Make sure state dict wasn't modified
assert_equal(state_dict, state_dict_c)
# Make sure state dict is deterministic with equal
# but not identical parameters
# NOTE: The state_dict of optimizers in PyTorch 1.5 have random keys,
state_dict = deepcopy(optimizer.state_dict())
state_dict_c = deepcopy(optimizer_c.state_dict())
keys = state_dict['param_groups'][-1]['params']
keys_c = state_dict_c['param_groups'][-1]['params']
for key, key_c in zip(keys, keys_c):
assert_equal(optimizer.state_dict()['state'][key],
optimizer_c.state_dict()['state'][key_c])
# Make sure repeated parameters have identical representation in state dict
optimizer_c.param_groups.extend(optimizer_c.param_groups)
assert_equal(optimizer_c.state_dict()['param_groups'][0],
optimizer_c.state_dict()['param_groups'][1])
# Check that state dict can be loaded even when we cast parameters
# to a different type and move to a different device.
if not torch.cuda.is_available():
return
input_cuda = Variable(inputs.data.float().cuda())
weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
optimizer_cuda = constructor(weight_cuda, bias_cuda)
fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda,
bias_cuda)
state_dict = deepcopy(optimizer.state_dict())
state_dict_c = deepcopy(optimizer.state_dict())
optimizer_cuda.load_state_dict(state_dict_c)
# Make sure state dict wasn't modified
assert_equal(state_dict, state_dict_c)
for _ in range(20):
optimizer.step(fn)
optimizer_cuda.step(fn_cuda)
assert_equal(weight, weight_cuda)
assert_equal(bias, bias_cuda)
# validate deepcopy() copies all public attributes
def getPublicAttr(obj):
return set(k for k in obj.__dict__ if not k.startswith('_'))
assert_equal(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer)))
def _test_basic_cases_template(weight, bias, inputs, constructor,
scheduler_constructors):
"""Copied from PyTorch."""
weight = Variable(weight, requires_grad=True)
bias = Variable(bias, requires_grad=True)
inputs = Variable(inputs)
optimizer = constructor(weight, bias)
schedulers = []
for scheduler_constructor in scheduler_constructors:
schedulers.append(scheduler_constructor(optimizer))
# to check if the optimizer can be printed as a string
optimizer.__repr__()
def fn():
optimizer.zero_grad()
y = weight.mv(inputs)
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
y = y.cuda(bias.get_device())
loss = (y + bias).pow(2).sum()
loss.backward()
return loss
initial_value = fn().item()
for _ in range(200):
for scheduler in schedulers:
scheduler.step()
optimizer.step(fn)
assert fn().item() < initial_value
def _test_basic_cases(constructor,
scheduler_constructors=None,
ignore_multidevice=False):
"""Copied from PyTorch."""
if scheduler_constructors is None:
scheduler_constructors = []
_test_state_dict(
torch.randn(10, 5), torch.randn(10), torch.randn(5), constructor)
_test_basic_cases_template(
torch.randn(10, 5), torch.randn(10), torch.randn(5), constructor,
scheduler_constructors)
# non-contiguous parameters
_test_basic_cases_template(
torch.randn(10, 5, 2)[..., 0],
torch.randn(10, 2)[..., 0], torch.randn(5), constructor,
scheduler_constructors)
# CUDA
if not torch.cuda.is_available():
return
_test_basic_cases_template(
torch.randn(10, 5).cuda(),
torch.randn(10).cuda(),
torch.randn(5).cuda(), constructor, scheduler_constructors)
# Multi-GPU
if not torch.cuda.device_count() > 1 or ignore_multidevice:
return
_test_basic_cases_template(
torch.randn(10, 5).cuda(0),
torch.randn(10).cuda(1),
torch.randn(5).cuda(0), constructor, scheduler_constructors)
def test_lamb_optimizer():
model = ExampleModel()
optimizer_cfg = dict(
type='Lamb',
lr=base_lr,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=base_wd,
paramwise_cfg=dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1))
optimizer = build_optimizer(model, optimizer_cfg)
check_lamb_optimizer(optimizer, model, **optimizer_cfg['paramwise_cfg'])
_test_basic_cases(lambda weight, bias: build_from_cfg(
dict(type='Lamb', params=[weight, bias], lr=base_lr), OPTIMIZERS))

View File

@@ -1,264 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
import torch.nn as nn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import EpochBasedRunner, IterBasedRunner, build_optimizer
from mmengine.logging import MMLogger, print_log
from torch.utils.data import DataLoader, Dataset
from mmcls.core.hook import PreciseBNHook
class ExampleDataset(Dataset):
def __init__(self):
self.index = 0
def __getitem__(self, idx):
results = dict(imgs=torch.tensor([1.0], dtype=torch.float32))
return results
def __len__(self):
return 1
class BiggerDataset(ExampleDataset):
def __init__(self, fixed_values=range(0, 12)):
assert len(self) == len(fixed_values)
self.fixed_values = fixed_values
def __getitem__(self, idx):
results = dict(
imgs=torch.tensor([self.fixed_values[idx]], dtype=torch.float32))
return results
def __len__(self):
# a bigger dataset
return 12
class ExampleModel(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Linear(1, 1)
self.bn = nn.BatchNorm1d(1)
self.test_cfg = None
def forward(self, imgs, return_loss=False):
return self.bn(self.conv(imgs))
def train_step(self, data_batch, optimizer, **kwargs):
outputs = {
'loss': 0.5,
'log_vars': {
'accuracy': 0.98
},
'num_samples': 1
}
return outputs
class SingleBNModel(ExampleModel):
def __init__(self):
super().__init__()
self.bn = nn.BatchNorm1d(1)
self.test_cfg = None
def forward(self, imgs, return_loss=False):
return self.bn(imgs)
class GNExampleModel(ExampleModel):
def __init__(self):
super().__init__()
self.conv = nn.Linear(1, 1)
self.bn = nn.GroupNorm(1, 1)
self.test_cfg = None
class NoBNExampleModel(ExampleModel):
def __init__(self):
super().__init__()
self.conv = nn.Linear(1, 1)
self.test_cfg = None
def forward(self, imgs, return_loss=False):
return self.conv(imgs)
def test_precise_bn():
optimizer_cfg = dict(
type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
test_dataset = ExampleDataset()
loader = DataLoader(test_dataset, batch_size=2)
model = ExampleModel()
optimizer = build_optimizer(model, optimizer_cfg)
logger = MMLogger.get_instance('precise_bn')
runner = EpochBasedRunner(
model=model,
batch_processor=None,
optimizer=optimizer,
logger=logger,
max_epochs=1)
with pytest.raises(AssertionError):
# num_samples must be larger than 0
precise_bn_hook = PreciseBNHook(num_samples=-1)
runner.register_hook(precise_bn_hook)
runner.run([loader], [('train', 1)])
with pytest.raises(AssertionError):
# interval must be larger than 0
precise_bn_hook = PreciseBNHook(interval=0)
runner.register_hook(precise_bn_hook)
runner.run([loader], [('train', 1)])
with pytest.raises(AssertionError):
# interval must be larger than 0
runner = EpochBasedRunner(
model=model,
batch_processor=None,
optimizer=optimizer,
logger=logger,
max_epochs=1)
precise_bn_hook = PreciseBNHook(interval=0)
runner.register_hook(precise_bn_hook)
runner.run([loader], [('train', 1)])
with pytest.raises(AssertionError):
# only support EpochBaseRunner
runner = IterBasedRunner(
model=model,
batch_processor=None,
optimizer=optimizer,
logger=logger,
max_epochs=1)
precise_bn_hook = PreciseBNHook(interval=2)
runner.register_hook(precise_bn_hook)
print_log(runner)
runner.run([loader], [('train', 1)])
# test non-DDP model
test_bigger_dataset = BiggerDataset()
loader = DataLoader(test_bigger_dataset, batch_size=2)
loaders = [loader]
precise_bn_hook = PreciseBNHook(num_samples=4)
assert precise_bn_hook.num_samples == 4
assert precise_bn_hook.interval == 1
runner = EpochBasedRunner(
model=model,
batch_processor=None,
optimizer=optimizer,
logger=logger,
max_epochs=1)
runner.register_hook(precise_bn_hook)
runner.run(loaders, [('train', 1)])
# test DP model
test_bigger_dataset = BiggerDataset()
loader = DataLoader(test_bigger_dataset, batch_size=2)
loaders = [loader]
precise_bn_hook = PreciseBNHook(num_samples=4)
assert precise_bn_hook.num_samples == 4
assert precise_bn_hook.interval == 1
model = MMDataParallel(model)
runner = EpochBasedRunner(
model=model,
batch_processor=None,
optimizer=optimizer,
logger=logger,
max_epochs=1)
runner.register_hook(precise_bn_hook)
runner.run(loaders, [('train', 1)])
# test model w/ gn layer
loader = DataLoader(test_bigger_dataset, batch_size=2)
loaders = [loader]
precise_bn_hook = PreciseBNHook(num_samples=4)
assert precise_bn_hook.num_samples == 4
assert precise_bn_hook.interval == 1
model = GNExampleModel()
runner = EpochBasedRunner(
model=model,
batch_processor=None,
optimizer=optimizer,
logger=logger,
max_epochs=1)
runner.register_hook(precise_bn_hook)
runner.run(loaders, [('train', 1)])
# test model without bn layer
loader = DataLoader(test_bigger_dataset, batch_size=2)
loaders = [loader]
precise_bn_hook = PreciseBNHook(num_samples=4)
assert precise_bn_hook.num_samples == 4
assert precise_bn_hook.interval == 1
model = NoBNExampleModel()
runner = EpochBasedRunner(
model=model,
batch_processor=None,
optimizer=optimizer,
logger=logger,
max_epochs=1)
runner.register_hook(precise_bn_hook)
runner.run(loaders, [('train', 1)])
# test how precise it is
loader = DataLoader(test_bigger_dataset, batch_size=2)
loaders = [loader]
precise_bn_hook = PreciseBNHook(num_samples=12)
assert precise_bn_hook.num_samples == 12
assert precise_bn_hook.interval == 1
model = SingleBNModel()
runner = EpochBasedRunner(
model=model,
batch_processor=None,
optimizer=optimizer,
logger=logger,
max_epochs=1)
runner.register_hook(precise_bn_hook)
runner.run(loaders, [('train', 1)])
imgs_list = list()
for loader in loaders:
for i, data in enumerate(loader):
imgs_list.append(np.array(data['imgs']))
mean = np.mean([np.mean(batch) for batch in imgs_list])
# Bessel's correction is used in PyTorch, therefore ddof=1
var = np.mean([np.var(batch, ddof=1) for batch in imgs_list])
assert np.equal(mean, np.array(
model.bn.running_mean)), (mean, np.array(model.bn.running_mean))
assert np.equal(var, np.array(
model.bn.running_var)), (var, np.array(model.bn.running_var))
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_ddp_model_precise_bn():
# test DDP model
test_bigger_dataset = BiggerDataset()
loader = DataLoader(test_bigger_dataset, batch_size=2)
loaders = [loader]
precise_bn_hook = PreciseBNHook(num_samples=5)
assert precise_bn_hook.num_samples == 5
assert precise_bn_hook.interval == 1
model = ExampleModel()
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=True)
runner = EpochBasedRunner(
model=model,
batch_processor=None,
optimizer=optimizer,
logger=logger,
max_epochs=1)
runner.register_hook(precise_bn_hook)
runner.run(loaders, [('train', 1)])

View File

@@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import tempfile
from mmengine.logging import MMLogger
from mmcls.utils import get_root_logger, load_json_log

View File

@@ -1,100 +0,0 @@
# Copyright (c) Open-MMLab. All rights reserved.
import os
import os.path as osp
import tempfile
from unittest.mock import MagicMock
import matplotlib.pyplot as plt
import mmcv
import numpy as np
import pytest
from mmcls.core import visualization as vis
def test_color():
assert vis.color_val_matplotlib(mmcv.Color.blue) == (0., 0., 1.)
assert vis.color_val_matplotlib('green') == (0., 1., 0.)
assert vis.color_val_matplotlib((1, 2, 3)) == (3 / 255, 2 / 255, 1 / 255)
assert vis.color_val_matplotlib(100) == (100 / 255, 100 / 255, 100 / 255)
assert vis.color_val_matplotlib(np.zeros(3, dtype=int)) == (0., 0., 0.)
# forbid white color
with pytest.raises(TypeError):
vis.color_val_matplotlib([255, 255, 255])
# forbid float
with pytest.raises(TypeError):
vis.color_val_matplotlib(1.0)
# overflowed
with pytest.raises(AssertionError):
vis.color_val_matplotlib((0, 0, 500))
def test_imshow_infos():
tmp_dir = osp.join(tempfile.gettempdir(), 'image_infos')
tmp_filename = osp.join(tmp_dir, 'image.jpg')
image = np.ones((10, 10, 3), np.uint8)
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
out_image = vis.imshow_infos(
image, result, out_file=tmp_filename, show=False)
assert osp.isfile(tmp_filename)
assert image.shape == out_image.shape
assert not np.allclose(image, out_image)
os.remove(tmp_filename)
# test grayscale images
image = np.ones((10, 10), np.uint8)
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
out_image = vis.imshow_infos(
image, result, out_file=tmp_filename, show=False)
assert osp.isfile(tmp_filename)
assert image.shape == out_image.shape[:2]
os.remove(tmp_filename)
def test_figure_context_manager():
# test show multiple images with the same figure.
images = [
np.random.randint(0, 255, (100, 100, 3), np.uint8) for _ in range(5)
]
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
with vis.ImshowInfosContextManager() as manager:
fig_show = manager.fig_show
fig_save = manager.fig_save
# Test time out
fig_show.canvas.start_event_loop = MagicMock()
fig_show.canvas.end_event_loop = MagicMock()
for image in images:
ret, out_image = manager.put_img_infos(image, result, show=True)
assert ret == 0
assert image.shape == out_image.shape
assert not np.allclose(image, out_image)
assert fig_show is manager.fig_show
assert fig_save is manager.fig_save
# Test continue key
fig_show.canvas.start_event_loop = (
lambda _: fig_show.canvas.key_press_event(' '))
for image in images:
ret, out_image = manager.put_img_infos(image, result, show=True)
assert ret == 0
assert image.shape == out_image.shape
assert not np.allclose(image, out_image)
assert fig_show is manager.fig_show
assert fig_save is manager.fig_save
# Test close figure manually
fig_show = manager.fig_show
def destroy(*_, **__):
fig_show.canvas.close_event()
plt.close(fig_show)
fig_show.canvas.start_event_loop = destroy
ret, out_image = manager.put_img_infos(images[0], result, show=True)
assert ret == 1
assert image.shape == out_image.shape
assert not np.allclose(image, out_image)
assert fig_save is manager.fig_save

View File

@@ -2,12 +2,12 @@
import argparse
import copy
import math
import pkg_resources
import re
from pathlib import Path
import mmcv
import numpy as np
import pkg_resources
from mmcv import Config, DictAction
from mmcv.utils import to_2tuple
from torch.nn import BatchNorm1d, BatchNorm2d, GroupNorm, LayerNorm