[Feature] Add Visualization hook
parent
c78b5597d8
commit
522ab1fd84
|
@ -17,6 +17,9 @@ default_hooks = dict(
|
||||||
|
|
||||||
# set sampler seed in distributed evrionment.
|
# set sampler seed in distributed evrionment.
|
||||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||||
|
|
||||||
|
# validation results visualization, set True to enable it.
|
||||||
|
visualization=dict(type='VisualizationHook', enable=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
# configure environment
|
# configure environment
|
||||||
|
|
|
@ -2,9 +2,11 @@
|
||||||
from .class_num_check_hook import ClassNumCheckHook
|
from .class_num_check_hook import ClassNumCheckHook
|
||||||
from .lr_updater import CosineAnnealingCooldownLrUpdaterHook
|
from .lr_updater import CosineAnnealingCooldownLrUpdaterHook
|
||||||
from .precise_bn_hook import PreciseBNHook
|
from .precise_bn_hook import PreciseBNHook
|
||||||
from .wandblogger_hook import MMClsWandbHook
|
from .visualization_hook import VisualizationHook
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'ClassNumCheckHook', 'PreciseBNHook',
|
'ClassNumCheckHook',
|
||||||
'CosineAnnealingCooldownLrUpdaterHook', 'MMClsWandbHook'
|
'PreciseBNHook',
|
||||||
|
'CosineAnnealingCooldownLrUpdaterHook',
|
||||||
|
'VisualizationHook',
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,131 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import math
|
||||||
|
import os.path as osp
|
||||||
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
|
from mmengine import FileClient
|
||||||
|
from mmengine.hooks import Hook
|
||||||
|
from mmengine.runner import EpochBasedTrainLoop, Runner
|
||||||
|
from mmengine.visualization import Visualizer
|
||||||
|
|
||||||
|
from mmcls.core import ClsDataSample
|
||||||
|
from mmcls.registry import HOOKS
|
||||||
|
|
||||||
|
|
||||||
|
@HOOKS.register_module()
|
||||||
|
class VisualizationHook(Hook):
|
||||||
|
"""Classification Visualization Hook. Used to visualize validation and
|
||||||
|
testing prediction results.
|
||||||
|
|
||||||
|
- If ``out_dir`` is specified, all storage backends are ignored
|
||||||
|
and save the image to the ``out_dir``.
|
||||||
|
- If ``show`` is True, plot the result image in a window, please
|
||||||
|
confirm you are able to access the graphical interface.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
enable (bool): Whether to enable this hook. Defaults to False.
|
||||||
|
interval (int): The interval of samples to visualize. Defaults to 5000.
|
||||||
|
show (bool): Whether to display the drawn image. Default to False.
|
||||||
|
out_dir (str, optional): directory where painted images will be saved
|
||||||
|
in the testing process. If None, handle with the backends of the
|
||||||
|
visualizer. Defaults to None.
|
||||||
|
**kwargs: other keyword arguments of
|
||||||
|
:meth:`mmcls.core.ClsVisualizer.add_datasample`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
enable=False,
|
||||||
|
interval: int = 5000,
|
||||||
|
show: bool = False,
|
||||||
|
out_dir: Optional[str] = None,
|
||||||
|
**kwargs):
|
||||||
|
self._visualizer: Visualizer = Visualizer.get_current_instance()
|
||||||
|
|
||||||
|
self.enable = enable
|
||||||
|
self.interval = interval
|
||||||
|
self.show = show
|
||||||
|
self.out_dir = out_dir
|
||||||
|
if out_dir is not None:
|
||||||
|
self.file_client = FileClient.infer_client(uri=out_dir)
|
||||||
|
else:
|
||||||
|
self.file_client = None
|
||||||
|
|
||||||
|
self.draw_args = {**kwargs, 'show': show}
|
||||||
|
|
||||||
|
def _draw_samples(self,
|
||||||
|
batch_idx: int,
|
||||||
|
data_batch: Sequence[dict],
|
||||||
|
outputs: Sequence[ClsDataSample],
|
||||||
|
step: int = 0) -> None:
|
||||||
|
"""Visualize every ``self.interval`` samples from a data batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_idx (int): The index of the current batch in the val loop.
|
||||||
|
data_batch (Sequence[dict]): Data from dataloader.
|
||||||
|
outputs (Sequence[:obj:`DetDataSample`]): Outputs from model.
|
||||||
|
step (int): Global step value to record. Default to 0.
|
||||||
|
"""
|
||||||
|
if self.enable is False:
|
||||||
|
return
|
||||||
|
|
||||||
|
batch_size = len(outputs)
|
||||||
|
start_idx = batch_size * batch_idx
|
||||||
|
end_idx = start_idx + batch_size
|
||||||
|
|
||||||
|
# The first index divisible by the interval, after the start index
|
||||||
|
first_sample_id = math.ceil(start_idx / self.interval) * self.interval
|
||||||
|
|
||||||
|
for sample_id in range(first_sample_id, end_idx, self.interval):
|
||||||
|
image = data_batch[sample_id - start_idx]['inputs']
|
||||||
|
image = image.permute(1, 2, 0).numpy().astype('uint8')
|
||||||
|
|
||||||
|
data_sample = outputs[sample_id - start_idx]
|
||||||
|
if 'img_path' in data_sample:
|
||||||
|
# osp.basename works on different platforms even file clients.
|
||||||
|
sample_name = osp.basename(data_sample.get('img_path'))
|
||||||
|
else:
|
||||||
|
sample_name = str(sample_id)
|
||||||
|
|
||||||
|
draw_args = self.draw_args
|
||||||
|
if self.out_dir is not None:
|
||||||
|
draw_args['out_file'] = self.file_client.join_path(
|
||||||
|
self.out_dir, f'{sample_name}_{step}.png')
|
||||||
|
|
||||||
|
self._visualizer.add_datasample(
|
||||||
|
sample_name,
|
||||||
|
image=image,
|
||||||
|
data_sample=data_sample,
|
||||||
|
step=step,
|
||||||
|
**self.draw_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
def after_val_iter(self, runner: Runner, batch_idx: int,
|
||||||
|
data_batch: Sequence[dict],
|
||||||
|
outputs: Sequence[ClsDataSample]) -> None:
|
||||||
|
"""Visualize every ``self.interval`` samples during validation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (:obj:`Runner`): The runner of the validation process.
|
||||||
|
batch_idx (int): The index of the current batch in the val loop.
|
||||||
|
data_batch (Sequence[dict]): Data from dataloader.
|
||||||
|
outputs (Sequence[:obj:`ClsDataSample`]): Outputs from model.
|
||||||
|
"""
|
||||||
|
if isinstance(runner.train_loop, EpochBasedTrainLoop):
|
||||||
|
step = runner.epoch
|
||||||
|
else:
|
||||||
|
step = runner.iter
|
||||||
|
|
||||||
|
self._draw_samples(batch_idx, data_batch, outputs, step=step)
|
||||||
|
|
||||||
|
def after_test_iter(self, runner: Runner, batch_idx: int,
|
||||||
|
data_batch: Sequence[dict],
|
||||||
|
outputs: Sequence[ClsDataSample]) -> None:
|
||||||
|
"""Visualize every ``self.interval`` samples during test.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (:obj:`Runner`): The runner of the testing process.
|
||||||
|
batch_idx (int): The index of the current batch in the test loop.
|
||||||
|
data_batch (Sequence[dict]): Data from dataloader.
|
||||||
|
outputs (Sequence[:obj:`DetDataSample`]): Outputs from model.
|
||||||
|
"""
|
||||||
|
self._draw_samples(batch_idx, data_batch, outputs, step=0)
|
|
@ -95,11 +95,10 @@ class ClsVisualizer(Visualizer):
|
||||||
step: int = 0) -> None:
|
step: int = 0) -> None:
|
||||||
"""Draw datasample and save to all backends.
|
"""Draw datasample and save to all backends.
|
||||||
|
|
||||||
- If ``show`` is True, all storage backends are ignored and then
|
- If ``out_file`` is specified, all storage backends are ignored
|
||||||
displayed in a local window.
|
and save the image to the ``out_file``.
|
||||||
- If the ``out_file`` parameter is specified, the drawn image
|
- If ``show`` is True, plot the result image in a window, please
|
||||||
will be additionally saved to ``out_file``. It is usually used
|
confirm you are able to access the graphical interface.
|
||||||
in script mode like ``image_demo.py``
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name (str): The image identifier.
|
name (str): The image identifier.
|
||||||
|
@ -121,8 +120,9 @@ class ClsVisualizer(Visualizer):
|
||||||
wait_time (float): The interval of show (s). Default to 0, which
|
wait_time (float): The interval of show (s). Default to 0, which
|
||||||
means "forever".
|
means "forever".
|
||||||
out_file (str, optional): Extra path to save the visualization
|
out_file (str, optional): Extra path to save the visualization
|
||||||
result. Whether specified or not, the visualizer will still
|
result. If specified, the visualizer will only save the result
|
||||||
save the results by its storage backends. Default to None.
|
image to the out_file and ignore its storage backends.
|
||||||
|
Default to None.
|
||||||
step (int): Global step value to record. Default to 0.
|
step (int): Global step value to record. Default to 0.
|
||||||
"""
|
"""
|
||||||
classes = None
|
classes = None
|
||||||
|
@ -179,8 +179,9 @@ class ClsVisualizer(Visualizer):
|
||||||
|
|
||||||
if show:
|
if show:
|
||||||
self.show(drawn_img, win_name=name, wait_time=wait_time)
|
self.show(drawn_img, win_name=name, wait_time=wait_time)
|
||||||
else:
|
|
||||||
self.add_image(name, drawn_img, step=step)
|
|
||||||
|
|
||||||
if out_file is not None:
|
if out_file is not None:
|
||||||
|
# save the image to the target file instead of vis_backends
|
||||||
mmcv.imwrite(drawn_img[..., ::-1], out_file)
|
mmcv.imwrite(drawn_img[..., ::-1], out_file)
|
||||||
|
else:
|
||||||
|
self.add_image(name, drawn_img, step=step)
|
||||||
|
|
|
@ -0,0 +1,152 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import os.path as osp
|
||||||
|
import tempfile
|
||||||
|
from unittest import TestCase
|
||||||
|
from unittest.mock import ANY, MagicMock, patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop
|
||||||
|
|
||||||
|
from mmcls.core import ClsDataSample, ClsVisualizer, VisualizationHook
|
||||||
|
from mmcls.registry import HOOKS
|
||||||
|
from mmcls.utils import register_all_modules
|
||||||
|
|
||||||
|
register_all_modules()
|
||||||
|
|
||||||
|
|
||||||
|
class TestVisualizationHook(TestCase):
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
ClsVisualizer.get_instance('visualizer')
|
||||||
|
|
||||||
|
data_sample = ClsDataSample().set_gt_label(1).set_pred_label(2)
|
||||||
|
data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
|
||||||
|
self.data_batch = [{
|
||||||
|
'inputs': torch.randint(0, 256, (3, 224, 224)),
|
||||||
|
'data_sample': data_sample
|
||||||
|
}] * 10
|
||||||
|
|
||||||
|
self.outputs = [data_sample] * 10
|
||||||
|
|
||||||
|
self.tmpdir = tempfile.TemporaryDirectory()
|
||||||
|
|
||||||
|
def test_initialize(self):
|
||||||
|
# test file_client
|
||||||
|
cfg = dict(type='VisualizationHook')
|
||||||
|
hook = HOOKS.build(cfg)
|
||||||
|
self.assertIsNone(hook.file_client)
|
||||||
|
|
||||||
|
cfg = dict(type='VisualizationHook', out_dir=self.tmpdir.name)
|
||||||
|
hook = HOOKS.build(cfg)
|
||||||
|
self.assertIsNotNone(hook.file_client)
|
||||||
|
|
||||||
|
# test draw_args
|
||||||
|
|
||||||
|
def test_draw_samples(self):
|
||||||
|
# test enable=False
|
||||||
|
cfg = dict(type='VisualizationHook', enable=False)
|
||||||
|
hook: VisualizationHook = HOOKS.build(cfg)
|
||||||
|
with patch.object(hook._visualizer, 'add_datasample') as mock:
|
||||||
|
hook._draw_samples(1, self.data_batch, self.outputs, step=1)
|
||||||
|
mock.assert_not_called()
|
||||||
|
|
||||||
|
# test enable=True
|
||||||
|
cfg = dict(type='VisualizationHook', enable=True, show=True)
|
||||||
|
hook: VisualizationHook = HOOKS.build(cfg)
|
||||||
|
with patch.object(hook._visualizer, 'add_datasample') as mock:
|
||||||
|
hook._draw_samples(0, self.data_batch, self.outputs, step=0)
|
||||||
|
mock.assert_called_once_with(
|
||||||
|
'color.jpg',
|
||||||
|
image=ANY,
|
||||||
|
data_sample=self.outputs[0],
|
||||||
|
step=0,
|
||||||
|
show=True)
|
||||||
|
|
||||||
|
# test samples without path
|
||||||
|
cfg = dict(type='VisualizationHook', enable=True)
|
||||||
|
hook: VisualizationHook = HOOKS.build(cfg)
|
||||||
|
with patch.object(hook._visualizer, 'add_datasample') as mock:
|
||||||
|
outputs = [ClsDataSample()] * 10
|
||||||
|
hook._draw_samples(0, self.data_batch, outputs, step=0)
|
||||||
|
mock.assert_called_once_with(
|
||||||
|
'0', image=ANY, data_sample=outputs[0], step=0, show=False)
|
||||||
|
|
||||||
|
# test out_dir
|
||||||
|
cfg = dict(
|
||||||
|
type='VisualizationHook', enable=True, out_dir=self.tmpdir.name)
|
||||||
|
hook: VisualizationHook = HOOKS.build(cfg)
|
||||||
|
with patch.object(hook._visualizer, 'add_datasample') as mock:
|
||||||
|
hook._draw_samples(0, self.data_batch, self.outputs, step=0)
|
||||||
|
mock.assert_called_once_with(
|
||||||
|
'color.jpg',
|
||||||
|
image=ANY,
|
||||||
|
data_sample=self.outputs[0],
|
||||||
|
step=0,
|
||||||
|
show=False,
|
||||||
|
out_file=osp.join(self.tmpdir.name, 'color.jpg_0.png'))
|
||||||
|
|
||||||
|
# test sample idx
|
||||||
|
cfg = dict(type='VisualizationHook', enable=True, interval=4)
|
||||||
|
hook: VisualizationHook = HOOKS.build(cfg)
|
||||||
|
with patch.object(hook._visualizer, 'add_datasample') as mock:
|
||||||
|
hook._draw_samples(1, self.data_batch, self.outputs, step=0)
|
||||||
|
mock.assert_called_with(
|
||||||
|
'color.jpg',
|
||||||
|
image=ANY,
|
||||||
|
data_sample=self.outputs[2],
|
||||||
|
step=0,
|
||||||
|
show=False)
|
||||||
|
mock.assert_called_with(
|
||||||
|
'color.jpg',
|
||||||
|
image=ANY,
|
||||||
|
data_sample=self.outputs[6],
|
||||||
|
step=0,
|
||||||
|
show=False)
|
||||||
|
|
||||||
|
def test_after_val_iter(self):
|
||||||
|
runner = MagicMock()
|
||||||
|
|
||||||
|
# test epoch-based
|
||||||
|
runner.train_loop = MagicMock(spec=EpochBasedTrainLoop)
|
||||||
|
runner.epoch = 5
|
||||||
|
cfg = dict(type='VisualizationHook', enable=True)
|
||||||
|
hook = HOOKS.build(cfg)
|
||||||
|
with patch.object(hook._visualizer, 'add_datasample') as mock:
|
||||||
|
hook.after_val_iter(runner, 0, self.data_batch, self.outputs)
|
||||||
|
mock.assert_called_once_with(
|
||||||
|
'color.jpg',
|
||||||
|
image=ANY,
|
||||||
|
data_sample=self.outputs[0],
|
||||||
|
step=5,
|
||||||
|
show=False)
|
||||||
|
|
||||||
|
# test iter-based
|
||||||
|
runner.train_loop = MagicMock(spec=IterBasedTrainLoop)
|
||||||
|
runner.iter = 300
|
||||||
|
cfg = dict(type='VisualizationHook', enable=True)
|
||||||
|
hook = HOOKS.build(cfg)
|
||||||
|
with patch.object(hook._visualizer, 'add_datasample') as mock:
|
||||||
|
hook.after_val_iter(runner, 0, self.data_batch, self.outputs)
|
||||||
|
mock.assert_called_once_with(
|
||||||
|
'color.jpg',
|
||||||
|
image=ANY,
|
||||||
|
data_sample=self.outputs[0],
|
||||||
|
step=300,
|
||||||
|
show=False)
|
||||||
|
|
||||||
|
def test_after_test_iter(self):
|
||||||
|
runner = MagicMock()
|
||||||
|
|
||||||
|
cfg = dict(type='VisualizationHook', enable=True)
|
||||||
|
hook = HOOKS.build(cfg)
|
||||||
|
with patch.object(hook._visualizer, 'add_datasample') as mock:
|
||||||
|
hook.after_test_iter(runner, 0, self.data_batch, self.outputs)
|
||||||
|
mock.assert_called_once_with(
|
||||||
|
'color.jpg',
|
||||||
|
image=ANY,
|
||||||
|
data_sample=self.outputs[0],
|
||||||
|
step=0,
|
||||||
|
show=False)
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
self.tmpdir.cleanup()
|
|
@ -27,6 +27,23 @@ def parse_args():
|
||||||
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
|
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
|
||||||
'Note that the quotation marks are necessary and that no white space '
|
'Note that the quotation marks are necessary and that no white space '
|
||||||
'is allowed.')
|
'is allowed.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--show-dir',
|
||||||
|
help='directory where the visualization images will be saved.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--show',
|
||||||
|
action='store_true',
|
||||||
|
help='whether to display the prediction results in a window.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--interval',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help='visualize per interval samples.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--wait-time',
|
||||||
|
type=float,
|
||||||
|
default=2,
|
||||||
|
help='display time of every window. (second)')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--launcher',
|
'--launcher',
|
||||||
choices=['none', 'pytorch', 'slurm', 'mpi'],
|
choices=['none', 'pytorch', 'slurm', 'mpi'],
|
||||||
|
@ -39,6 +56,23 @@ def parse_args():
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def merge_args(cfg, args):
|
||||||
|
"""Merge CLI arguments to config."""
|
||||||
|
# -------------------- visualization --------------------
|
||||||
|
if args.show or (args.show_dir is not None):
|
||||||
|
assert 'visualization' in cfg.default_hooks, \
|
||||||
|
'VisualizationHook is not set in the `default_hooks` field of ' \
|
||||||
|
'config. Please set `visualization=dict(type="VisualizationHook")`'
|
||||||
|
|
||||||
|
cfg.default_hooks.visualization.enable = True
|
||||||
|
cfg.default_hooks.visualization.show = args.show
|
||||||
|
cfg.default_hooks.visualization.wait_time = args.wait_time
|
||||||
|
cfg.default_hooks.visualization.out_dir = args.show_dir
|
||||||
|
cfg.default_hooks.visualization.interval = args.interval
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
|
@ -48,6 +82,7 @@ def main():
|
||||||
|
|
||||||
# load config
|
# load config
|
||||||
cfg = Config.fromfile(args.config)
|
cfg = Config.fromfile(args.config)
|
||||||
|
cfg = merge_args(cfg, args)
|
||||||
cfg.launcher = args.launcher
|
cfg.launcher = args.launcher
|
||||||
if args.cfg_options is not None:
|
if args.cfg_options is not None:
|
||||||
cfg.merge_from_dict(args.cfg_options)
|
cfg.merge_from_dict(args.cfg_options)
|
||||||
|
|
Loading…
Reference in New Issue