[Feature] Add Visualization hook

pull/913/head
mzr1996 2022-06-06 02:32:22 +00:00
parent c78b5597d8
commit 522ab1fd84
6 changed files with 336 additions and 12 deletions
configs/_base_
tests/test_core/test_hooks
tools

View File

@ -17,6 +17,9 @@ default_hooks = dict(
# set sampler seed in distributed evrionment.
sampler_seed=dict(type='DistSamplerSeedHook'),
# validation results visualization, set True to enable it.
visualization=dict(type='VisualizationHook', enable=False),
)
# configure environment

View File

@ -2,9 +2,11 @@
from .class_num_check_hook import ClassNumCheckHook
from .lr_updater import CosineAnnealingCooldownLrUpdaterHook
from .precise_bn_hook import PreciseBNHook
from .wandblogger_hook import MMClsWandbHook
from .visualization_hook import VisualizationHook
__all__ = [
'ClassNumCheckHook', 'PreciseBNHook',
'CosineAnnealingCooldownLrUpdaterHook', 'MMClsWandbHook'
'ClassNumCheckHook',
'PreciseBNHook',
'CosineAnnealingCooldownLrUpdaterHook',
'VisualizationHook',
]

View File

@ -0,0 +1,131 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import os.path as osp
from typing import Optional, Sequence
from mmengine import FileClient
from mmengine.hooks import Hook
from mmengine.runner import EpochBasedTrainLoop, Runner
from mmengine.visualization import Visualizer
from mmcls.core import ClsDataSample
from mmcls.registry import HOOKS
@HOOKS.register_module()
class VisualizationHook(Hook):
"""Classification Visualization Hook. Used to visualize validation and
testing prediction results.
- If ``out_dir`` is specified, all storage backends are ignored
and save the image to the ``out_dir``.
- If ``show`` is True, plot the result image in a window, please
confirm you are able to access the graphical interface.
Args:
enable (bool): Whether to enable this hook. Defaults to False.
interval (int): The interval of samples to visualize. Defaults to 5000.
show (bool): Whether to display the drawn image. Default to False.
out_dir (str, optional): directory where painted images will be saved
in the testing process. If None, handle with the backends of the
visualizer. Defaults to None.
**kwargs: other keyword arguments of
:meth:`mmcls.core.ClsVisualizer.add_datasample`.
"""
def __init__(self,
enable=False,
interval: int = 5000,
show: bool = False,
out_dir: Optional[str] = None,
**kwargs):
self._visualizer: Visualizer = Visualizer.get_current_instance()
self.enable = enable
self.interval = interval
self.show = show
self.out_dir = out_dir
if out_dir is not None:
self.file_client = FileClient.infer_client(uri=out_dir)
else:
self.file_client = None
self.draw_args = {**kwargs, 'show': show}
def _draw_samples(self,
batch_idx: int,
data_batch: Sequence[dict],
outputs: Sequence[ClsDataSample],
step: int = 0) -> None:
"""Visualize every ``self.interval`` samples from a data batch.
Args:
batch_idx (int): The index of the current batch in the val loop.
data_batch (Sequence[dict]): Data from dataloader.
outputs (Sequence[:obj:`DetDataSample`]): Outputs from model.
step (int): Global step value to record. Default to 0.
"""
if self.enable is False:
return
batch_size = len(outputs)
start_idx = batch_size * batch_idx
end_idx = start_idx + batch_size
# The first index divisible by the interval, after the start index
first_sample_id = math.ceil(start_idx / self.interval) * self.interval
for sample_id in range(first_sample_id, end_idx, self.interval):
image = data_batch[sample_id - start_idx]['inputs']
image = image.permute(1, 2, 0).numpy().astype('uint8')
data_sample = outputs[sample_id - start_idx]
if 'img_path' in data_sample:
# osp.basename works on different platforms even file clients.
sample_name = osp.basename(data_sample.get('img_path'))
else:
sample_name = str(sample_id)
draw_args = self.draw_args
if self.out_dir is not None:
draw_args['out_file'] = self.file_client.join_path(
self.out_dir, f'{sample_name}_{step}.png')
self._visualizer.add_datasample(
sample_name,
image=image,
data_sample=data_sample,
step=step,
**self.draw_args,
)
def after_val_iter(self, runner: Runner, batch_idx: int,
data_batch: Sequence[dict],
outputs: Sequence[ClsDataSample]) -> None:
"""Visualize every ``self.interval`` samples during validation.
Args:
runner (:obj:`Runner`): The runner of the validation process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (Sequence[dict]): Data from dataloader.
outputs (Sequence[:obj:`ClsDataSample`]): Outputs from model.
"""
if isinstance(runner.train_loop, EpochBasedTrainLoop):
step = runner.epoch
else:
step = runner.iter
self._draw_samples(batch_idx, data_batch, outputs, step=step)
def after_test_iter(self, runner: Runner, batch_idx: int,
data_batch: Sequence[dict],
outputs: Sequence[ClsDataSample]) -> None:
"""Visualize every ``self.interval`` samples during test.
Args:
runner (:obj:`Runner`): The runner of the testing process.
batch_idx (int): The index of the current batch in the test loop.
data_batch (Sequence[dict]): Data from dataloader.
outputs (Sequence[:obj:`DetDataSample`]): Outputs from model.
"""
self._draw_samples(batch_idx, data_batch, outputs, step=0)

View File

@ -95,11 +95,10 @@ class ClsVisualizer(Visualizer):
step: int = 0) -> None:
"""Draw datasample and save to all backends.
- If ``show`` is True, all storage backends are ignored and then
displayed in a local window.
- If the ``out_file`` parameter is specified, the drawn image
will be additionally saved to ``out_file``. It is usually used
in script mode like ``image_demo.py``
- If ``out_file`` is specified, all storage backends are ignored
and save the image to the ``out_file``.
- If ``show`` is True, plot the result image in a window, please
confirm you are able to access the graphical interface.
Args:
name (str): The image identifier.
@ -121,8 +120,9 @@ class ClsVisualizer(Visualizer):
wait_time (float): The interval of show (s). Default to 0, which
means "forever".
out_file (str, optional): Extra path to save the visualization
result. Whether specified or not, the visualizer will still
save the results by its storage backends. Default to None.
result. If specified, the visualizer will only save the result
image to the out_file and ignore its storage backends.
Default to None.
step (int): Global step value to record. Default to 0.
"""
classes = None
@ -179,8 +179,9 @@ class ClsVisualizer(Visualizer):
if show:
self.show(drawn_img, win_name=name, wait_time=wait_time)
else:
self.add_image(name, drawn_img, step=step)
if out_file is not None:
# save the image to the target file instead of vis_backends
mmcv.imwrite(drawn_img[..., ::-1], out_file)
else:
self.add_image(name, drawn_img, step=step)

View File

@ -0,0 +1,152 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
from unittest.mock import ANY, MagicMock, patch
import torch
from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop
from mmcls.core import ClsDataSample, ClsVisualizer, VisualizationHook
from mmcls.registry import HOOKS
from mmcls.utils import register_all_modules
register_all_modules()
class TestVisualizationHook(TestCase):
def setUp(self) -> None:
ClsVisualizer.get_instance('visualizer')
data_sample = ClsDataSample().set_gt_label(1).set_pred_label(2)
data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
self.data_batch = [{
'inputs': torch.randint(0, 256, (3, 224, 224)),
'data_sample': data_sample
}] * 10
self.outputs = [data_sample] * 10
self.tmpdir = tempfile.TemporaryDirectory()
def test_initialize(self):
# test file_client
cfg = dict(type='VisualizationHook')
hook = HOOKS.build(cfg)
self.assertIsNone(hook.file_client)
cfg = dict(type='VisualizationHook', out_dir=self.tmpdir.name)
hook = HOOKS.build(cfg)
self.assertIsNotNone(hook.file_client)
# test draw_args
def test_draw_samples(self):
# test enable=False
cfg = dict(type='VisualizationHook', enable=False)
hook: VisualizationHook = HOOKS.build(cfg)
with patch.object(hook._visualizer, 'add_datasample') as mock:
hook._draw_samples(1, self.data_batch, self.outputs, step=1)
mock.assert_not_called()
# test enable=True
cfg = dict(type='VisualizationHook', enable=True, show=True)
hook: VisualizationHook = HOOKS.build(cfg)
with patch.object(hook._visualizer, 'add_datasample') as mock:
hook._draw_samples(0, self.data_batch, self.outputs, step=0)
mock.assert_called_once_with(
'color.jpg',
image=ANY,
data_sample=self.outputs[0],
step=0,
show=True)
# test samples without path
cfg = dict(type='VisualizationHook', enable=True)
hook: VisualizationHook = HOOKS.build(cfg)
with patch.object(hook._visualizer, 'add_datasample') as mock:
outputs = [ClsDataSample()] * 10
hook._draw_samples(0, self.data_batch, outputs, step=0)
mock.assert_called_once_with(
'0', image=ANY, data_sample=outputs[0], step=0, show=False)
# test out_dir
cfg = dict(
type='VisualizationHook', enable=True, out_dir=self.tmpdir.name)
hook: VisualizationHook = HOOKS.build(cfg)
with patch.object(hook._visualizer, 'add_datasample') as mock:
hook._draw_samples(0, self.data_batch, self.outputs, step=0)
mock.assert_called_once_with(
'color.jpg',
image=ANY,
data_sample=self.outputs[0],
step=0,
show=False,
out_file=osp.join(self.tmpdir.name, 'color.jpg_0.png'))
# test sample idx
cfg = dict(type='VisualizationHook', enable=True, interval=4)
hook: VisualizationHook = HOOKS.build(cfg)
with patch.object(hook._visualizer, 'add_datasample') as mock:
hook._draw_samples(1, self.data_batch, self.outputs, step=0)
mock.assert_called_with(
'color.jpg',
image=ANY,
data_sample=self.outputs[2],
step=0,
show=False)
mock.assert_called_with(
'color.jpg',
image=ANY,
data_sample=self.outputs[6],
step=0,
show=False)
def test_after_val_iter(self):
runner = MagicMock()
# test epoch-based
runner.train_loop = MagicMock(spec=EpochBasedTrainLoop)
runner.epoch = 5
cfg = dict(type='VisualizationHook', enable=True)
hook = HOOKS.build(cfg)
with patch.object(hook._visualizer, 'add_datasample') as mock:
hook.after_val_iter(runner, 0, self.data_batch, self.outputs)
mock.assert_called_once_with(
'color.jpg',
image=ANY,
data_sample=self.outputs[0],
step=5,
show=False)
# test iter-based
runner.train_loop = MagicMock(spec=IterBasedTrainLoop)
runner.iter = 300
cfg = dict(type='VisualizationHook', enable=True)
hook = HOOKS.build(cfg)
with patch.object(hook._visualizer, 'add_datasample') as mock:
hook.after_val_iter(runner, 0, self.data_batch, self.outputs)
mock.assert_called_once_with(
'color.jpg',
image=ANY,
data_sample=self.outputs[0],
step=300,
show=False)
def test_after_test_iter(self):
runner = MagicMock()
cfg = dict(type='VisualizationHook', enable=True)
hook = HOOKS.build(cfg)
with patch.object(hook._visualizer, 'add_datasample') as mock:
hook.after_test_iter(runner, 0, self.data_batch, self.outputs)
mock.assert_called_once_with(
'color.jpg',
image=ANY,
data_sample=self.outputs[0],
step=0,
show=False)
def tearDown(self) -> None:
self.tmpdir.cleanup()

View File

@ -27,6 +27,23 @@ def parse_args():
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--show-dir',
help='directory where the visualization images will be saved.')
parser.add_argument(
'--show',
action='store_true',
help='whether to display the prediction results in a window.')
parser.add_argument(
'--interval',
type=int,
default=1,
help='visualize per interval samples.')
parser.add_argument(
'--wait-time',
type=float,
default=2,
help='display time of every window. (second)')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
@ -39,6 +56,23 @@ def parse_args():
return args
def merge_args(cfg, args):
"""Merge CLI arguments to config."""
# -------------------- visualization --------------------
if args.show or (args.show_dir is not None):
assert 'visualization' in cfg.default_hooks, \
'VisualizationHook is not set in the `default_hooks` field of ' \
'config. Please set `visualization=dict(type="VisualizationHook")`'
cfg.default_hooks.visualization.enable = True
cfg.default_hooks.visualization.show = args.show
cfg.default_hooks.visualization.wait_time = args.wait_time
cfg.default_hooks.visualization.out_dir = args.show_dir
cfg.default_hooks.visualization.interval = args.interval
return cfg
def main():
args = parse_args()
@ -48,6 +82,7 @@ def main():
# load config
cfg = Config.fromfile(args.config)
cfg = merge_args(cfg, args)
cfg.launcher = args.launcher
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)