[Feature] Add Visualization hook
parent
c78b5597d8
commit
522ab1fd84
configs/_base_
mmcls/core
visualization
tests/test_core/test_hooks
tools
|
@ -17,6 +17,9 @@ default_hooks = dict(
|
|||
|
||||
# set sampler seed in distributed evrionment.
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
|
||||
# validation results visualization, set True to enable it.
|
||||
visualization=dict(type='VisualizationHook', enable=False),
|
||||
)
|
||||
|
||||
# configure environment
|
||||
|
|
|
@ -2,9 +2,11 @@
|
|||
from .class_num_check_hook import ClassNumCheckHook
|
||||
from .lr_updater import CosineAnnealingCooldownLrUpdaterHook
|
||||
from .precise_bn_hook import PreciseBNHook
|
||||
from .wandblogger_hook import MMClsWandbHook
|
||||
from .visualization_hook import VisualizationHook
|
||||
|
||||
__all__ = [
|
||||
'ClassNumCheckHook', 'PreciseBNHook',
|
||||
'CosineAnnealingCooldownLrUpdaterHook', 'MMClsWandbHook'
|
||||
'ClassNumCheckHook',
|
||||
'PreciseBNHook',
|
||||
'CosineAnnealingCooldownLrUpdaterHook',
|
||||
'VisualizationHook',
|
||||
]
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
import os.path as osp
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from mmengine import FileClient
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.runner import EpochBasedTrainLoop, Runner
|
||||
from mmengine.visualization import Visualizer
|
||||
|
||||
from mmcls.core import ClsDataSample
|
||||
from mmcls.registry import HOOKS
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class VisualizationHook(Hook):
|
||||
"""Classification Visualization Hook. Used to visualize validation and
|
||||
testing prediction results.
|
||||
|
||||
- If ``out_dir`` is specified, all storage backends are ignored
|
||||
and save the image to the ``out_dir``.
|
||||
- If ``show`` is True, plot the result image in a window, please
|
||||
confirm you are able to access the graphical interface.
|
||||
|
||||
Args:
|
||||
enable (bool): Whether to enable this hook. Defaults to False.
|
||||
interval (int): The interval of samples to visualize. Defaults to 5000.
|
||||
show (bool): Whether to display the drawn image. Default to False.
|
||||
out_dir (str, optional): directory where painted images will be saved
|
||||
in the testing process. If None, handle with the backends of the
|
||||
visualizer. Defaults to None.
|
||||
**kwargs: other keyword arguments of
|
||||
:meth:`mmcls.core.ClsVisualizer.add_datasample`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
enable=False,
|
||||
interval: int = 5000,
|
||||
show: bool = False,
|
||||
out_dir: Optional[str] = None,
|
||||
**kwargs):
|
||||
self._visualizer: Visualizer = Visualizer.get_current_instance()
|
||||
|
||||
self.enable = enable
|
||||
self.interval = interval
|
||||
self.show = show
|
||||
self.out_dir = out_dir
|
||||
if out_dir is not None:
|
||||
self.file_client = FileClient.infer_client(uri=out_dir)
|
||||
else:
|
||||
self.file_client = None
|
||||
|
||||
self.draw_args = {**kwargs, 'show': show}
|
||||
|
||||
def _draw_samples(self,
|
||||
batch_idx: int,
|
||||
data_batch: Sequence[dict],
|
||||
outputs: Sequence[ClsDataSample],
|
||||
step: int = 0) -> None:
|
||||
"""Visualize every ``self.interval`` samples from a data batch.
|
||||
|
||||
Args:
|
||||
batch_idx (int): The index of the current batch in the val loop.
|
||||
data_batch (Sequence[dict]): Data from dataloader.
|
||||
outputs (Sequence[:obj:`DetDataSample`]): Outputs from model.
|
||||
step (int): Global step value to record. Default to 0.
|
||||
"""
|
||||
if self.enable is False:
|
||||
return
|
||||
|
||||
batch_size = len(outputs)
|
||||
start_idx = batch_size * batch_idx
|
||||
end_idx = start_idx + batch_size
|
||||
|
||||
# The first index divisible by the interval, after the start index
|
||||
first_sample_id = math.ceil(start_idx / self.interval) * self.interval
|
||||
|
||||
for sample_id in range(first_sample_id, end_idx, self.interval):
|
||||
image = data_batch[sample_id - start_idx]['inputs']
|
||||
image = image.permute(1, 2, 0).numpy().astype('uint8')
|
||||
|
||||
data_sample = outputs[sample_id - start_idx]
|
||||
if 'img_path' in data_sample:
|
||||
# osp.basename works on different platforms even file clients.
|
||||
sample_name = osp.basename(data_sample.get('img_path'))
|
||||
else:
|
||||
sample_name = str(sample_id)
|
||||
|
||||
draw_args = self.draw_args
|
||||
if self.out_dir is not None:
|
||||
draw_args['out_file'] = self.file_client.join_path(
|
||||
self.out_dir, f'{sample_name}_{step}.png')
|
||||
|
||||
self._visualizer.add_datasample(
|
||||
sample_name,
|
||||
image=image,
|
||||
data_sample=data_sample,
|
||||
step=step,
|
||||
**self.draw_args,
|
||||
)
|
||||
|
||||
def after_val_iter(self, runner: Runner, batch_idx: int,
|
||||
data_batch: Sequence[dict],
|
||||
outputs: Sequence[ClsDataSample]) -> None:
|
||||
"""Visualize every ``self.interval`` samples during validation.
|
||||
|
||||
Args:
|
||||
runner (:obj:`Runner`): The runner of the validation process.
|
||||
batch_idx (int): The index of the current batch in the val loop.
|
||||
data_batch (Sequence[dict]): Data from dataloader.
|
||||
outputs (Sequence[:obj:`ClsDataSample`]): Outputs from model.
|
||||
"""
|
||||
if isinstance(runner.train_loop, EpochBasedTrainLoop):
|
||||
step = runner.epoch
|
||||
else:
|
||||
step = runner.iter
|
||||
|
||||
self._draw_samples(batch_idx, data_batch, outputs, step=step)
|
||||
|
||||
def after_test_iter(self, runner: Runner, batch_idx: int,
|
||||
data_batch: Sequence[dict],
|
||||
outputs: Sequence[ClsDataSample]) -> None:
|
||||
"""Visualize every ``self.interval`` samples during test.
|
||||
|
||||
Args:
|
||||
runner (:obj:`Runner`): The runner of the testing process.
|
||||
batch_idx (int): The index of the current batch in the test loop.
|
||||
data_batch (Sequence[dict]): Data from dataloader.
|
||||
outputs (Sequence[:obj:`DetDataSample`]): Outputs from model.
|
||||
"""
|
||||
self._draw_samples(batch_idx, data_batch, outputs, step=0)
|
|
@ -95,11 +95,10 @@ class ClsVisualizer(Visualizer):
|
|||
step: int = 0) -> None:
|
||||
"""Draw datasample and save to all backends.
|
||||
|
||||
- If ``show`` is True, all storage backends are ignored and then
|
||||
displayed in a local window.
|
||||
- If the ``out_file`` parameter is specified, the drawn image
|
||||
will be additionally saved to ``out_file``. It is usually used
|
||||
in script mode like ``image_demo.py``
|
||||
- If ``out_file`` is specified, all storage backends are ignored
|
||||
and save the image to the ``out_file``.
|
||||
- If ``show`` is True, plot the result image in a window, please
|
||||
confirm you are able to access the graphical interface.
|
||||
|
||||
Args:
|
||||
name (str): The image identifier.
|
||||
|
@ -121,8 +120,9 @@ class ClsVisualizer(Visualizer):
|
|||
wait_time (float): The interval of show (s). Default to 0, which
|
||||
means "forever".
|
||||
out_file (str, optional): Extra path to save the visualization
|
||||
result. Whether specified or not, the visualizer will still
|
||||
save the results by its storage backends. Default to None.
|
||||
result. If specified, the visualizer will only save the result
|
||||
image to the out_file and ignore its storage backends.
|
||||
Default to None.
|
||||
step (int): Global step value to record. Default to 0.
|
||||
"""
|
||||
classes = None
|
||||
|
@ -179,8 +179,9 @@ class ClsVisualizer(Visualizer):
|
|||
|
||||
if show:
|
||||
self.show(drawn_img, win_name=name, wait_time=wait_time)
|
||||
else:
|
||||
self.add_image(name, drawn_img, step=step)
|
||||
|
||||
if out_file is not None:
|
||||
# save the image to the target file instead of vis_backends
|
||||
mmcv.imwrite(drawn_img[..., ::-1], out_file)
|
||||
else:
|
||||
self.add_image(name, drawn_img, step=step)
|
||||
|
|
|
@ -0,0 +1,152 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
from unittest import TestCase
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import torch
|
||||
from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop
|
||||
|
||||
from mmcls.core import ClsDataSample, ClsVisualizer, VisualizationHook
|
||||
from mmcls.registry import HOOKS
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
class TestVisualizationHook(TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
ClsVisualizer.get_instance('visualizer')
|
||||
|
||||
data_sample = ClsDataSample().set_gt_label(1).set_pred_label(2)
|
||||
data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
|
||||
self.data_batch = [{
|
||||
'inputs': torch.randint(0, 256, (3, 224, 224)),
|
||||
'data_sample': data_sample
|
||||
}] * 10
|
||||
|
||||
self.outputs = [data_sample] * 10
|
||||
|
||||
self.tmpdir = tempfile.TemporaryDirectory()
|
||||
|
||||
def test_initialize(self):
|
||||
# test file_client
|
||||
cfg = dict(type='VisualizationHook')
|
||||
hook = HOOKS.build(cfg)
|
||||
self.assertIsNone(hook.file_client)
|
||||
|
||||
cfg = dict(type='VisualizationHook', out_dir=self.tmpdir.name)
|
||||
hook = HOOKS.build(cfg)
|
||||
self.assertIsNotNone(hook.file_client)
|
||||
|
||||
# test draw_args
|
||||
|
||||
def test_draw_samples(self):
|
||||
# test enable=False
|
||||
cfg = dict(type='VisualizationHook', enable=False)
|
||||
hook: VisualizationHook = HOOKS.build(cfg)
|
||||
with patch.object(hook._visualizer, 'add_datasample') as mock:
|
||||
hook._draw_samples(1, self.data_batch, self.outputs, step=1)
|
||||
mock.assert_not_called()
|
||||
|
||||
# test enable=True
|
||||
cfg = dict(type='VisualizationHook', enable=True, show=True)
|
||||
hook: VisualizationHook = HOOKS.build(cfg)
|
||||
with patch.object(hook._visualizer, 'add_datasample') as mock:
|
||||
hook._draw_samples(0, self.data_batch, self.outputs, step=0)
|
||||
mock.assert_called_once_with(
|
||||
'color.jpg',
|
||||
image=ANY,
|
||||
data_sample=self.outputs[0],
|
||||
step=0,
|
||||
show=True)
|
||||
|
||||
# test samples without path
|
||||
cfg = dict(type='VisualizationHook', enable=True)
|
||||
hook: VisualizationHook = HOOKS.build(cfg)
|
||||
with patch.object(hook._visualizer, 'add_datasample') as mock:
|
||||
outputs = [ClsDataSample()] * 10
|
||||
hook._draw_samples(0, self.data_batch, outputs, step=0)
|
||||
mock.assert_called_once_with(
|
||||
'0', image=ANY, data_sample=outputs[0], step=0, show=False)
|
||||
|
||||
# test out_dir
|
||||
cfg = dict(
|
||||
type='VisualizationHook', enable=True, out_dir=self.tmpdir.name)
|
||||
hook: VisualizationHook = HOOKS.build(cfg)
|
||||
with patch.object(hook._visualizer, 'add_datasample') as mock:
|
||||
hook._draw_samples(0, self.data_batch, self.outputs, step=0)
|
||||
mock.assert_called_once_with(
|
||||
'color.jpg',
|
||||
image=ANY,
|
||||
data_sample=self.outputs[0],
|
||||
step=0,
|
||||
show=False,
|
||||
out_file=osp.join(self.tmpdir.name, 'color.jpg_0.png'))
|
||||
|
||||
# test sample idx
|
||||
cfg = dict(type='VisualizationHook', enable=True, interval=4)
|
||||
hook: VisualizationHook = HOOKS.build(cfg)
|
||||
with patch.object(hook._visualizer, 'add_datasample') as mock:
|
||||
hook._draw_samples(1, self.data_batch, self.outputs, step=0)
|
||||
mock.assert_called_with(
|
||||
'color.jpg',
|
||||
image=ANY,
|
||||
data_sample=self.outputs[2],
|
||||
step=0,
|
||||
show=False)
|
||||
mock.assert_called_with(
|
||||
'color.jpg',
|
||||
image=ANY,
|
||||
data_sample=self.outputs[6],
|
||||
step=0,
|
||||
show=False)
|
||||
|
||||
def test_after_val_iter(self):
|
||||
runner = MagicMock()
|
||||
|
||||
# test epoch-based
|
||||
runner.train_loop = MagicMock(spec=EpochBasedTrainLoop)
|
||||
runner.epoch = 5
|
||||
cfg = dict(type='VisualizationHook', enable=True)
|
||||
hook = HOOKS.build(cfg)
|
||||
with patch.object(hook._visualizer, 'add_datasample') as mock:
|
||||
hook.after_val_iter(runner, 0, self.data_batch, self.outputs)
|
||||
mock.assert_called_once_with(
|
||||
'color.jpg',
|
||||
image=ANY,
|
||||
data_sample=self.outputs[0],
|
||||
step=5,
|
||||
show=False)
|
||||
|
||||
# test iter-based
|
||||
runner.train_loop = MagicMock(spec=IterBasedTrainLoop)
|
||||
runner.iter = 300
|
||||
cfg = dict(type='VisualizationHook', enable=True)
|
||||
hook = HOOKS.build(cfg)
|
||||
with patch.object(hook._visualizer, 'add_datasample') as mock:
|
||||
hook.after_val_iter(runner, 0, self.data_batch, self.outputs)
|
||||
mock.assert_called_once_with(
|
||||
'color.jpg',
|
||||
image=ANY,
|
||||
data_sample=self.outputs[0],
|
||||
step=300,
|
||||
show=False)
|
||||
|
||||
def test_after_test_iter(self):
|
||||
runner = MagicMock()
|
||||
|
||||
cfg = dict(type='VisualizationHook', enable=True)
|
||||
hook = HOOKS.build(cfg)
|
||||
with patch.object(hook._visualizer, 'add_datasample') as mock:
|
||||
hook.after_test_iter(runner, 0, self.data_batch, self.outputs)
|
||||
mock.assert_called_once_with(
|
||||
'color.jpg',
|
||||
image=ANY,
|
||||
data_sample=self.outputs[0],
|
||||
step=0,
|
||||
show=False)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.tmpdir.cleanup()
|
|
@ -27,6 +27,23 @@ def parse_args():
|
|||
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
|
||||
'Note that the quotation marks are necessary and that no white space '
|
||||
'is allowed.')
|
||||
parser.add_argument(
|
||||
'--show-dir',
|
||||
help='directory where the visualization images will be saved.')
|
||||
parser.add_argument(
|
||||
'--show',
|
||||
action='store_true',
|
||||
help='whether to display the prediction results in a window.')
|
||||
parser.add_argument(
|
||||
'--interval',
|
||||
type=int,
|
||||
default=1,
|
||||
help='visualize per interval samples.')
|
||||
parser.add_argument(
|
||||
'--wait-time',
|
||||
type=float,
|
||||
default=2,
|
||||
help='display time of every window. (second)')
|
||||
parser.add_argument(
|
||||
'--launcher',
|
||||
choices=['none', 'pytorch', 'slurm', 'mpi'],
|
||||
|
@ -39,6 +56,23 @@ def parse_args():
|
|||
return args
|
||||
|
||||
|
||||
def merge_args(cfg, args):
|
||||
"""Merge CLI arguments to config."""
|
||||
# -------------------- visualization --------------------
|
||||
if args.show or (args.show_dir is not None):
|
||||
assert 'visualization' in cfg.default_hooks, \
|
||||
'VisualizationHook is not set in the `default_hooks` field of ' \
|
||||
'config. Please set `visualization=dict(type="VisualizationHook")`'
|
||||
|
||||
cfg.default_hooks.visualization.enable = True
|
||||
cfg.default_hooks.visualization.show = args.show
|
||||
cfg.default_hooks.visualization.wait_time = args.wait_time
|
||||
cfg.default_hooks.visualization.out_dir = args.show_dir
|
||||
cfg.default_hooks.visualization.interval = args.interval
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
|
@ -48,6 +82,7 @@ def main():
|
|||
|
||||
# load config
|
||||
cfg = Config.fromfile(args.config)
|
||||
cfg = merge_args(cfg, args)
|
||||
cfg.launcher = args.launcher
|
||||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
|
|
Loading…
Reference in New Issue