diff --git a/configs/_base_/default_runtime.py b/configs/_base_/default_runtime.py
index 4cfe0ffe..8937bf3a 100644
--- a/configs/_base_/default_runtime.py
+++ b/configs/_base_/default_runtime.py
@@ -17,6 +17,9 @@ default_hooks = dict(
 
     # set sampler seed in distributed evrionment.
     sampler_seed=dict(type='DistSamplerSeedHook'),
+
+    # validation results visualization, set enable=True to turn it on.
+    visualization=dict(type='VisualizationHook', enable=False),
 )
 
 # configure environment
diff --git a/mmcls/core/hook/__init__.py b/mmcls/core/hook/__init__.py
index 4212dcf9..42b17c1c 100644
--- a/mmcls/core/hook/__init__.py
+++ b/mmcls/core/hook/__init__.py
@@ -2,9 +2,11 @@
 from .class_num_check_hook import ClassNumCheckHook
 from .lr_updater import CosineAnnealingCooldownLrUpdaterHook
 from .precise_bn_hook import PreciseBNHook
-from .wandblogger_hook import MMClsWandbHook
+from .visualization_hook import VisualizationHook
 
 __all__ = [
-    'ClassNumCheckHook', 'PreciseBNHook',
-    'CosineAnnealingCooldownLrUpdaterHook', 'MMClsWandbHook'
+    'ClassNumCheckHook',
+    'PreciseBNHook',
+    'CosineAnnealingCooldownLrUpdaterHook',
+    'VisualizationHook',
 ]
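With the hook registered and exported above, enabling it from a downstream config is a one-line override of the base runtime. A minimal sketch (the `_base_` path and the `interval` value here are illustrative, not part of this patch):

```python
_base_ = ['../_base_/default_runtime.py']

# Turn on validation-result visualization and draw one sample
# every 100 samples instead of the default 5000.
default_hooks = dict(
    visualization=dict(type='VisualizationHook', enable=True, interval=100))
```

The new `VisualizationHook` itself follows.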
+ """ + if self.enable is False: + return + + batch_size = len(outputs) + start_idx = batch_size * batch_idx + end_idx = start_idx + batch_size + + # The first index divisible by the interval, after the start index + first_sample_id = math.ceil(start_idx / self.interval) * self.interval + + for sample_id in range(first_sample_id, end_idx, self.interval): + image = data_batch[sample_id - start_idx]['inputs'] + image = image.permute(1, 2, 0).numpy().astype('uint8') + + data_sample = outputs[sample_id - start_idx] + if 'img_path' in data_sample: + # osp.basename works on different platforms even file clients. + sample_name = osp.basename(data_sample.get('img_path')) + else: + sample_name = str(sample_id) + + draw_args = self.draw_args + if self.out_dir is not None: + draw_args['out_file'] = self.file_client.join_path( + self.out_dir, f'{sample_name}_{step}.png') + + self._visualizer.add_datasample( + sample_name, + image=image, + data_sample=data_sample, + step=step, + **self.draw_args, + ) + + def after_val_iter(self, runner: Runner, batch_idx: int, + data_batch: Sequence[dict], + outputs: Sequence[ClsDataSample]) -> None: + """Visualize every ``self.interval`` samples during validation. + + Args: + runner (:obj:`Runner`): The runner of the validation process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (Sequence[dict]): Data from dataloader. + outputs (Sequence[:obj:`ClsDataSample`]): Outputs from model. + """ + if isinstance(runner.train_loop, EpochBasedTrainLoop): + step = runner.epoch + else: + step = runner.iter + + self._draw_samples(batch_idx, data_batch, outputs, step=step) + + def after_test_iter(self, runner: Runner, batch_idx: int, + data_batch: Sequence[dict], + outputs: Sequence[ClsDataSample]) -> None: + """Visualize every ``self.interval`` samples during test. + + Args: + runner (:obj:`Runner`): The runner of the testing process. + batch_idx (int): The index of the current batch in the test loop. + data_batch (Sequence[dict]): Data from dataloader. + outputs (Sequence[:obj:`DetDataSample`]): Outputs from model. + """ + self._draw_samples(batch_idx, data_batch, outputs, step=0) diff --git a/mmcls/core/visualization/cls_visualizer.py b/mmcls/core/visualization/cls_visualizer.py index 1b3fb177..5a6bad30 100644 --- a/mmcls/core/visualization/cls_visualizer.py +++ b/mmcls/core/visualization/cls_visualizer.py @@ -95,11 +95,10 @@ class ClsVisualizer(Visualizer): step: int = 0) -> None: """Draw datasample and save to all backends. - - If ``show`` is True, all storage backends are ignored and then - displayed in a local window. - - If the ``out_file`` parameter is specified, the drawn image - will be additionally saved to ``out_file``. It is usually used - in script mode like ``image_demo.py`` + - If ``out_file`` is specified, all storage backends are ignored + and save the image to the ``out_file``. + - If ``show`` is True, plot the result image in a window, please + confirm you are able to access the graphical interface. Args: name (str): The image identifier. @@ -121,8 +120,9 @@ class ClsVisualizer(Visualizer): wait_time (float): The interval of show (s). Default to 0, which means "forever". out_file (str, optional): Extra path to save the visualization - result. Whether specified or not, the visualizer will still - save the results by its storage backends. Default to None. + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Default to None. 
             step (int): Global step value to record. Default to 0.
         """
         classes = None
@@ -179,8 +179,9 @@ class ClsVisualizer(Visualizer):
 
         if show:
            self.show(drawn_img, win_name=name, wait_time=wait_time)
-        else:
-            self.add_image(name, drawn_img, step=step)
 
         if out_file is not None:
+            # save the image to the target file instead of vis_backends
             mmcv.imwrite(drawn_img[..., ::-1], out_file)
+        else:
+            self.add_image(name, drawn_img, step=step)
diff --git a/tests/test_core/test_hooks/test_visualization_hook.py b/tests/test_core/test_hooks/test_visualization_hook.py
new file mode 100644
index 00000000..023fc24f
--- /dev/null
+++ b/tests/test_core/test_hooks/test_visualization_hook.py
@@ -0,0 +1,154 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import tempfile
+from unittest import TestCase
+from unittest.mock import ANY, MagicMock, patch
+
+import torch
+from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop
+
+from mmcls.core import ClsDataSample, ClsVisualizer, VisualizationHook
+from mmcls.registry import HOOKS
+from mmcls.utils import register_all_modules
+
+register_all_modules()
+
+
+class TestVisualizationHook(TestCase):
+
+    def setUp(self) -> None:
+        ClsVisualizer.get_instance('visualizer')
+
+        data_sample = ClsDataSample().set_gt_label(1).set_pred_label(2)
+        data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
+        self.data_batch = [{
+            'inputs': torch.randint(0, 256, (3, 224, 224)),
+            'data_sample': data_sample
+        }] * 10
+
+        self.outputs = [data_sample] * 10
+
+        self.tmpdir = tempfile.TemporaryDirectory()
+
+    def test_initialize(self):
+        # test file_client
+        cfg = dict(type='VisualizationHook')
+        hook = HOOKS.build(cfg)
+        self.assertIsNone(hook.file_client)
+
+        cfg = dict(type='VisualizationHook', out_dir=self.tmpdir.name)
+        hook = HOOKS.build(cfg)
+        self.assertIsNotNone(hook.file_client)
+
+        # test draw_args
+        cfg = dict(type='VisualizationHook', show=True, wait_time=3)
+        hook = HOOKS.build(cfg)
+        self.assertEqual(hook.draw_args, dict(show=True, wait_time=3))
+
+    def test_draw_samples(self):
+        # test enable=False
+        cfg = dict(type='VisualizationHook', enable=False)
+        hook: VisualizationHook = HOOKS.build(cfg)
+        with patch.object(hook._visualizer, 'add_datasample') as mock:
+            hook._draw_samples(1, self.data_batch, self.outputs, step=1)
+            mock.assert_not_called()
+
+        # test enable=True
+        cfg = dict(type='VisualizationHook', enable=True, show=True)
+        hook: VisualizationHook = HOOKS.build(cfg)
+        with patch.object(hook._visualizer, 'add_datasample') as mock:
+            hook._draw_samples(0, self.data_batch, self.outputs, step=0)
+            mock.assert_called_once_with(
+                'color.jpg',
+                image=ANY,
+                data_sample=self.outputs[0],
+                step=0,
+                show=True)
+
+        # test samples without path
+        cfg = dict(type='VisualizationHook', enable=True)
+        hook: VisualizationHook = HOOKS.build(cfg)
+        with patch.object(hook._visualizer, 'add_datasample') as mock:
+            outputs = [ClsDataSample()] * 10
+            hook._draw_samples(0, self.data_batch, outputs, step=0)
+            mock.assert_called_once_with(
+                '0', image=ANY, data_sample=outputs[0], step=0, show=False)
+
+        # test out_dir
+        cfg = dict(
+            type='VisualizationHook', enable=True, out_dir=self.tmpdir.name)
+        hook: VisualizationHook = HOOKS.build(cfg)
+        with patch.object(hook._visualizer, 'add_datasample') as mock:
+            hook._draw_samples(0, self.data_batch, self.outputs, step=0)
+            mock.assert_called_once_with(
+                'color.jpg',
+                image=ANY,
+                data_sample=self.outputs[0],
+                step=0,
+                show=False,
+                out_file=osp.join(self.tmpdir.name, 'color.jpg_0.png'))
+
+        # test sample idx
+        cfg = dict(type='VisualizationHook', enable=True, interval=4)
+        hook: VisualizationHook = HOOKS.build(cfg)
+        with patch.object(hook._visualizer, 'add_datasample') as mock:
+            hook._draw_samples(1, self.data_batch, self.outputs, step=0)
+            # batch_idx=1 covers global samples 10~19; with interval=4,
+            # samples 12 and 16 (batch offsets 2 and 6) are drawn.
+            # ``assert_called_with`` only checks the most recent call, so
+            # verify the call count and the last call.
+            self.assertEqual(mock.call_count, 2)
+            mock.assert_called_with(
+                'color.jpg',
+                image=ANY,
+                data_sample=self.outputs[6],
+                step=0,
+                show=False)
+
+    def test_after_val_iter(self):
+        runner = MagicMock()
+
+        # test epoch-based
+        runner.train_loop = MagicMock(spec=EpochBasedTrainLoop)
+        runner.epoch = 5
+        cfg = dict(type='VisualizationHook', enable=True)
+        hook = HOOKS.build(cfg)
+        with patch.object(hook._visualizer, 'add_datasample') as mock:
+            hook.after_val_iter(runner, 0, self.data_batch, self.outputs)
+            mock.assert_called_once_with(
+                'color.jpg',
+                image=ANY,
+                data_sample=self.outputs[0],
+                step=5,
+                show=False)
+
+        # test iter-based
+        runner.train_loop = MagicMock(spec=IterBasedTrainLoop)
+        runner.iter = 300
+        cfg = dict(type='VisualizationHook', enable=True)
+        hook = HOOKS.build(cfg)
+        with patch.object(hook._visualizer, 'add_datasample') as mock:
+            hook.after_val_iter(runner, 0, self.data_batch, self.outputs)
+            mock.assert_called_once_with(
+                'color.jpg',
+                image=ANY,
+                data_sample=self.outputs[0],
+                step=300,
+                show=False)
+
+    def test_after_test_iter(self):
+        runner = MagicMock()
+
+        cfg = dict(type='VisualizationHook', enable=True)
+        hook = HOOKS.build(cfg)
+        with patch.object(hook._visualizer, 'add_datasample') as mock:
+            hook.after_test_iter(runner, 0, self.data_batch, self.outputs)
+            mock.assert_called_once_with(
+                'color.jpg',
+                image=ANY,
+                data_sample=self.outputs[0],
+                step=0,
+                show=False)
+
+    def tearDown(self) -> None:
+        self.tmpdir.cleanup()
diff --git a/tools/test.py b/tools/test.py
index 5b7d484c..65f1debe 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -27,6 +27,23 @@
         'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
         'Note that the quotation marks are necessary and that no white space '
         'is allowed.')
+    parser.add_argument(
+        '--show-dir',
+        help='directory where the visualization images will be saved.')
+    parser.add_argument(
+        '--show',
+        action='store_true',
+        help='whether to display the prediction results in a window.')
+    parser.add_argument(
+        '--interval',
+        type=int,
+        default=1,
+        help='the interval of samples to visualize.')
+    parser.add_argument(
+        '--wait-time',
+        type=float,
+        default=2,
+        help='the display time of every window, in seconds.')
     parser.add_argument(
         '--launcher',
         choices=['none', 'pytorch', 'slurm', 'mpi'],
@@ -39,6 +56,23 @@
     return args
 
 
+def merge_args(cfg, args):
+    """Merge CLI arguments into the config."""
+    # -------------------- visualization --------------------
+    if args.show or (args.show_dir is not None):
+        assert 'visualization' in cfg.default_hooks, \
+            'VisualizationHook is not set in the `default_hooks` field of ' \
+            'config. Please set `visualization=dict(type="VisualizationHook")`'
+
+        cfg.default_hooks.visualization.enable = True
+        cfg.default_hooks.visualization.show = args.show
+        cfg.default_hooks.visualization.wait_time = args.wait_time
+        cfg.default_hooks.visualization.out_dir = args.show_dir
+        cfg.default_hooks.visualization.interval = args.interval
+
+    return cfg
+
+
 def main():
     args = parse_args()
 
@@ -48,6 +82,7 @@
 
     # load config
     cfg = Config.fromfile(args.config)
+    cfg = merge_args(cfg, args)
     cfg.launcher = args.launcher
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
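With `merge_args` in place, the hook can be driven entirely from the command line, e.g. `python tools/test.py ${CONFIG} ${CHECKPOINT} --show-dir work_dirs/vis --interval 10` (the config and checkpoint paths are placeholders). For reference, a minimal sketch of the sample-selection arithmetic that `_draw_samples` uses and that the `test sample idx` case above asserts; the helper name is illustrative, not part of this patch:

```python
import math


def drawn_indices(batch_idx, batch_size=10, interval=4):
    """Global sample indices visualized for one batch.

    Mirrors the arithmetic in VisualizationHook._draw_samples.
    """
    start_idx = batch_size * batch_idx
    end_idx = start_idx + batch_size
    # The first index divisible by the interval, after the start index.
    first_sample_id = math.ceil(start_idx / interval) * interval
    return list(range(first_sample_id, end_idx, interval))


assert drawn_indices(0) == [0, 4, 8]
# batch_idx=1 covers samples 10~19, so samples 12 and 16 are drawn,
# i.e. batch offsets 2 and 6, matching the unit test expectations.
assert drawn_indices(1) == [12, 16]
```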