diff --git a/mmseg/engine/__init__.py b/mmseg/engine/__init__.py index 36a88d71f..517f811d5 100644 --- a/mmseg/engine/__init__.py +++ b/mmseg/engine/__init__.py @@ -1,7 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .hooks import SegVisualizationHook from .optimizers import (LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor) +from .visualization import SegLocalVisualizer __all__ = [ - 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor' + 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor', + 'SegVisualizationHook', 'SegLocalVisualizer' ] diff --git a/mmseg/engine/hooks/__init__.py b/mmseg/engine/hooks/__init__.py new file mode 100644 index 000000000..c6048088a --- /dev/null +++ b/mmseg/engine/hooks/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .visualization_hook import SegVisualizationHook + +__all__ = ['SegVisualizationHook'] diff --git a/mmseg/engine/hooks/visualization_hook.py b/mmseg/engine/hooks/visualization_hook.py new file mode 100644 index 000000000..bd9bc2f3a --- /dev/null +++ b/mmseg/engine/hooks/visualization_hook.py @@ -0,0 +1,101 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import warnings +from typing import Sequence + +import mmcv +from mmengine.hooks import Hook +from mmengine.runner import Runner + +from mmseg.data import SegDataSample +from mmseg.engine.visualization import SegLocalVisualizer +from mmseg.registry import HOOKS + + +@HOOKS.register_module() +class SegVisualizationHook(Hook): + """Segmentation Visualization Hook. Used to visualize validation and + testing process prediction results. + + In the testing phase: + + 1. If ``show`` is True, it means that only the prediction results are + visualized without storing data, so ``vis_backends`` needs to + be excluded. + + Args: + draw (bool): whether to draw prediction results. If it is False, + it means that no drawing will be done. 
Defaults to False. + interval (int): The interval of visualization. Defaults to 50. + show (bool): Whether to display the drawn image. Defaults to False. + wait_time (float): The interval of show (s). Defaults to 0. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + """ + + def __init__(self, + draw: bool = False, + interval: int = 50, + show: bool = False, + wait_time: float = 0., + file_client_args: dict = dict(backend='disk')): + self._visualizer: SegLocalVisualizer = \ + SegLocalVisualizer.get_current_instance() + self.interval = interval + self.show = show + if self.show: + # No need to think about vis backends. + self._visualizer._vis_backends = {} + warnings.warn('The show is True, it means that only ' + 'the prediction results are visualized ' + 'without storing data, so vis_backends ' + 'needs to be excluded.') + + self.wait_time = wait_time + self.file_client_args = file_client_args.copy() + self.file_client = None + self.draw = draw + if not self.draw: + warnings.warn('The draw is False, it means that the ' + 'hook for visualization will not take ' + 'effect. The results will NOT be ' + 'visualized or stored.') + + def after_iter(self, + runner: Runner, + batch_idx: int, + data_batch: Sequence[dict], + outputs: Sequence[SegDataSample], + mode: str = 'val') -> None: + """Run after every ``self.interval`` validation or test iterations. + + Args: + runner (:obj:`Runner`): The runner of the validation process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (Sequence[dict]): Data from dataloader. + outputs (Sequence[:obj:`SegDataSample`]): Outputs from model. + mode (str): Current mode of runner. Defaults to 'val'. 
+ """ + if self.draw is False or mode == 'train': + return + + if self.file_client is None: + self.file_client = mmcv.FileClient(**self.file_client_args) + + if self.every_n_inner_iters(batch_idx, self.interval): + for input_data, output in zip(data_batch, outputs): + img_path = input_data['data_sample'].img_path + img_bytes = self.file_client.get(img_path) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + window_name = f'{mode}_{osp.basename(img_path)}' + + gt_sample = input_data['data_sample'] + self._visualizer.add_datasample( + window_name, + img, + gt_sample=gt_sample, + pred_sample=output, + show=self.show, + wait_time=self.wait_time, + step=runner.iter) diff --git a/mmseg/engine/visualization/__init__.py b/mmseg/engine/visualization/__init__.py new file mode 100644 index 000000000..8cbb211e5 --- /dev/null +++ b/mmseg/engine/visualization/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .local_visualizer import SegLocalVisualizer + +__all__ = ['SegLocalVisualizer'] diff --git a/mmseg/engine/visualization/local_visualizer.py b/mmseg/engine/visualization/local_visualizer.py new file mode 100644 index 000000000..ea966fa5b --- /dev/null +++ b/mmseg/engine/visualization/local_visualizer.py @@ -0,0 +1,172 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple + +import numpy as np +from mmengine import Visualizer +from mmengine.data import PixelData +from mmengine.dist import master_only + +from mmseg.data import SegDataSample +from mmseg.registry import VISUALIZERS + + +@VISUALIZERS.register_module() +class SegLocalVisualizer(Visualizer): + """MMSegmentation Local Visualizer. + + Args: + name (str): Name of the instance. Defaults to 'visualizer'. + image (np.ndarray, optional): the origin image to draw. The format + should be RGB. Defaults to None. + vis_backends (list, optional): Visual backend config list. + Defaults to None. 
+ save_dir (str, optional): Save file dir for all storage backends. + If it is None, the backend storage will not save any data. + alpha (int, float): The transparency of segmentation mask. + Defaults to 0.8. + + Examples: + >>> import numpy as np + >>> import torch + >>> from mmengine.data import PixelData + >>> from mmseg.data import SegDataSample + >>> from mmseg.engine.visualization import SegLocalVisualizer + + >>> seg_local_visualizer = SegLocalVisualizer() + >>> image = np.random.randint(0, 256, + ... size=(10, 12, 3)).astype('uint8') + >>> gt_sem_seg_data = dict(data=torch.randint(0, 2, (1, 10, 12))) + >>> gt_sem_seg = PixelData(**gt_sem_seg_data) + >>> gt_seg_data_sample = SegDataSample() + >>> gt_seg_data_sample.gt_sem_seg = gt_sem_seg + >>> seg_local_visualizer.dataset_meta = dict( + ... classes=('background', 'foreground'), + ... palette=[[120, 120, 120], [6, 230, 230]]) + >>> seg_local_visualizer.add_datasample('visualizer_example', + ... image, gt_seg_data_sample) + >>> seg_local_visualizer.add_datasample( + ... 'visualizer_example', image, + ... gt_seg_data_sample, show=True) + """ + + def __init__(self, + name: str = 'visualizer', + image: Optional[np.ndarray] = None, + vis_backends: Optional[Dict] = None, + save_dir: Optional[str] = None, + alpha: float = 0.8, + **kwargs): + super().__init__(name, image, vis_backends, save_dir, **kwargs) + self.alpha = alpha + # Set default value. When calling + # `SegLocalVisualizer().dataset_meta=xxx`, + # it will override the default value. + self.dataset_meta = {} + + def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData, + classes: Optional[Tuple[str]], + palette: Optional[List[List[int]]]) -> np.ndarray: + """Draw semantic seg of GT or prediction. + + Args: + image (np.ndarray): The image to draw. + sem_seg (:obj:`PixelData`): Data structure for + pixel-level annotations or predictions. + classes (Tuple[str], optional): Category information. 
+ palette (List[List[int]], optional): The palette of + segmentation map. + + Returns: + np.ndarray: The drawn image, in RGB channel order. + """ + num_classes = len(classes) + + sem_seg = sem_seg.data + ids = np.unique(sem_seg)[::-1] + legal_indices = ids < num_classes + ids = ids[legal_indices] + labels = np.array(ids, dtype=np.int64) + + colors = [palette[label] for label in labels] + + self.set_image(image) + + # draw semantic masks + for label, color in zip(labels, colors): + self.draw_binary_masks( + sem_seg == label, colors=[color], alphas=self.alpha) + + return self.get_image() + + @master_only + def add_datasample(self, + name: str, + image: np.ndarray, + gt_sample: Optional[SegDataSample] = None, + pred_sample: Optional[SegDataSample] = None, + draw_gt: bool = True, + draw_pred: bool = True, + show: bool = False, + wait_time: float = 0, + step: int = 0) -> None: + """Draw datasample and save to all backends. + + - If GT and prediction are plotted at the same time, they are + displayed in a stitched image where the left image is the + ground truth and the right image is the prediction. + - If ``show`` is True, all storage backends are ignored, and + the images will be displayed in a local window. + + Args: + name (str): The image identifier. + image (np.ndarray): The image to draw. + gt_sample (:obj:`SegDataSample`, optional): GT SegDataSample. + Defaults to None. + pred_sample (:obj:`SegDataSample`, optional): Prediction + SegDataSample. Defaults to None. + draw_gt (bool): Whether to draw GT SegDataSample. Defaults to True. + draw_pred (bool): Whether to draw Prediction SegDataSample. + Defaults to True. + show (bool): Whether to display the drawn image. Defaults to False. + wait_time (float): The interval of show (s). Defaults to 0. + step (int): Global step value to record. Defaults to 0. 
+ """ + classes = self.dataset_meta.get('classes', None) + palette = self.dataset_meta.get('palette', None) + + gt_img_data = None + pred_img_data = None + + if draw_gt and gt_sample is not None: + gt_img_data = image + if 'gt_sem_seg' in gt_sample: + assert classes is not None, 'class information is ' \ + 'not provided when ' \ + 'visualizing semantic ' \ + 'segmentation results.' + gt_img_data = self._draw_sem_seg(gt_img_data, + gt_sample.gt_sem_seg, classes, + palette) + + if draw_pred and pred_sample is not None: + pred_img_data = image + if 'pred_sem_seg' in pred_sample: + assert classes is not None, 'class information is ' \ + 'not provided when ' \ + 'visualizing semantic ' \ + 'segmentation results.' + pred_img_data = self._draw_sem_seg(pred_img_data, + pred_sample.pred_sem_seg, + classes, palette) + + if gt_img_data is not None and pred_img_data is not None: + drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1) + elif gt_img_data is not None: + drawn_img = gt_img_data + else: + drawn_img = pred_img_data + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + else: + self.add_image(name, drawn_img, step) diff --git a/tests/test_engine/test_local_visualizer.py b/tests/test_engine/test_local_visualizer.py new file mode 100644 index 000000000..6100fe856 --- /dev/null +++ b/tests/test_engine/test_local_visualizer.py @@ -0,0 +1,187 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os +import os.path as osp +from unittest import TestCase + +import cv2 +import mmcv +import numpy as np +import torch +from mmengine.data import PixelData + +from mmseg.data import SegDataSample +from mmseg.engine.visualization import SegLocalVisualizer + + +class TestSegLocalVisualizer(TestCase): + + def test_add_datasample(self): + h = 10 + w = 12 + num_class = 2 + out_file = 'out_file' + + image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8') + + # test gt_sem_seg + gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w))) + gt_sem_seg = PixelData(**gt_sem_seg_data) + + gt_seg_data_sample = SegDataSample() + gt_seg_data_sample.gt_sem_seg = gt_sem_seg + + seg_local_visualizer = SegLocalVisualizer( + vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir') + seg_local_visualizer.dataset_meta = dict( + classes=('background', 'foreground'), + palette=[[120, 120, 120], [6, 230, 230]]) + seg_local_visualizer.add_datasample(out_file, image, + gt_seg_data_sample) + + # test out_file + seg_local_visualizer.add_datasample(out_file, image, + gt_seg_data_sample) + + assert os.path.exists( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) + drawn_img = cv2.imread( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) + assert drawn_img.shape == (h, w, 3) + + os.remove( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) + os.rmdir('temp_dir' + '/vis_data/vis_image') + + # test gt_instances and pred_instances + pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w))) + pred_sem_seg = PixelData(**pred_sem_seg_data) + + pred_seg_data_sample = SegDataSample() + pred_seg_data_sample.pred_sem_seg = pred_sem_seg + + seg_local_visualizer.add_datasample(out_file, image, + gt_seg_data_sample, + pred_seg_data_sample) + self._assert_image_and_shape( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), + (h, w * 2, 3)) + + seg_local_visualizer.add_datasample( + out_file, 
+ image, + gt_seg_data_sample, + pred_seg_data_sample, + draw_gt=False) + self._assert_image_and_shape( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), + (h, w, 3)) + + seg_local_visualizer.add_datasample( + out_file, + image, + gt_seg_data_sample, + pred_seg_data_sample, + draw_pred=False) + self._assert_image_and_shape( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), + (h, w, 3)) + os.rmdir('temp_dir/vis_data') + os.rmdir('temp_dir') + + def test_cityscapes_add_datasample(self): + h = 128 + w = 256 + num_class = 19 + out_file = 'out_file_cityscapes' + + image = mmcv.imread( + osp.join( + osp.dirname(__file__), + '../data/pseudo_cityscapes_dataset/leftImg8bit/frankfurt_000000_000294_leftImg8bit.png' # noqa + ), + 'color') + sem_seg = mmcv.imread( + osp.join( + osp.dirname(__file__), + '../data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_labelTrainIds.png' # noqa + ), + 'unchanged') + sem_seg = torch.unsqueeze(torch.from_numpy(sem_seg), 0) + gt_sem_seg_data = dict(data=sem_seg) + gt_sem_seg = PixelData(**gt_sem_seg_data) + + gt_seg_data_sample = SegDataSample() + gt_seg_data_sample.gt_sem_seg = gt_sem_seg + + seg_local_visualizer = SegLocalVisualizer( + vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir') + seg_local_visualizer.dataset_meta = dict( + classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', + 'traffic light', 'traffic sign', 'vegetation', 'terrain', + 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', + 'motorcycle', 'bicycle'), + palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70], + [102, 102, 156], [190, 153, 153], [153, 153, 153], + [250, 170, 30], [220, 220, 0], [107, 142, 35], + [152, 251, 152], [70, 130, 180], [220, 20, 60], + [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], + [0, 80, 100], [0, 0, 230], [119, 11, 32]]) + seg_local_visualizer.add_datasample(out_file, image, + gt_seg_data_sample) + + # test out_file + 
seg_local_visualizer.add_datasample(out_file, image, + gt_seg_data_sample) + assert os.path.exists( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) + drawn_img = cv2.imread( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) + assert drawn_img.shape == (h, w, 3) + + os.remove( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) + os.rmdir('temp_dir/vis_data/vis_image') + + # test gt_instances and pred_instances + pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w))) + pred_sem_seg = PixelData(**pred_sem_seg_data) + + pred_seg_data_sample = SegDataSample() + pred_seg_data_sample.pred_sem_seg = pred_sem_seg + + seg_local_visualizer.add_datasample(out_file, image, + gt_seg_data_sample, + pred_seg_data_sample) + self._assert_image_and_shape( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), + (h, w * 2, 3)) + + seg_local_visualizer.add_datasample( + out_file, + image, + gt_seg_data_sample, + pred_seg_data_sample, + draw_gt=False) + self._assert_image_and_shape( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), + (h, w, 3)) + + seg_local_visualizer.add_datasample( + out_file, + image, + gt_seg_data_sample, + pred_seg_data_sample, + draw_pred=False) + self._assert_image_and_shape( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), + (h, w, 3)) + + os.rmdir('temp_dir/vis_data') + os.rmdir('temp_dir') + + def _assert_image_and_shape(self, out_file, out_shape): + assert os.path.exists(out_file) + drawn_img = cv2.imread(out_file) + assert drawn_img.shape == out_shape + os.remove(out_file) + os.rmdir('temp_dir/vis_data/vis_image') diff --git a/tests/test_engine/test_visualization_hook.py b/tests/test_engine/test_visualization_hook.py new file mode 100644 index 000000000..a70fb612e --- /dev/null +++ b/tests/test_engine/test_visualization_hook.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase +from unittest.mock import Mock + +import torch +from mmengine.data import PixelData + +from mmseg.data import SegDataSample +from mmseg.engine.hooks import SegVisualizationHook +from mmseg.engine.visualization import SegLocalVisualizer + + +class TestVisualizationHook(TestCase): + + def setUp(self) -> None: + """Create the shared visualizer and a two-sample fake batch.""" + h = 288 + w = 512 + num_class = 2 + # Configure the registered instance rather than the class. + visualizer = SegLocalVisualizer.get_instance('visualizer') + visualizer.dataset_meta = dict( + classes=('background', 'foreground'), + palette=[[120, 120, 120], [6, 230, 230]]) + + data_sample = SegDataSample() + data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'}) + self.data_batch = [{'data_sample': data_sample}] * 2 + + pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w))) + pred_sem_seg = PixelData(**pred_sem_seg_data) + pred_seg_data_sample = SegDataSample() + pred_seg_data_sample.pred_sem_seg = pred_sem_seg + self.outputs = [pred_seg_data_sample] * 2 + + def test_after_iter(self): + runner = Mock() + runner.iter = 1 + hook = SegVisualizationHook(draw=True, interval=1) + hook._after_iter( + runner, 1, self.data_batch, self.outputs, mode='train') + hook._after_iter(runner, 1, self.data_batch, self.outputs, mode='val') + hook._after_iter(runner, 1, self.data_batch, self.outputs, mode='test') + + def test_after_val_iter(self): + runner = Mock() + runner.iter = 2 + hook = SegVisualizationHook(interval=1) + hook.after_val_iter(runner, 1, self.data_batch, self.outputs) + + hook = SegVisualizationHook(draw=True, interval=1) + hook.after_val_iter(runner, 1, self.data_batch, self.outputs) + + hook = SegVisualizationHook( + draw=True, interval=1, show=True, wait_time=1) + hook.after_val_iter(runner, 1, self.data_batch, self.outputs) + + def test_after_test_iter(self): + runner = Mock() + runner.iter = 3 + hook = SegVisualizationHook(draw=True, interval=1) + hook.after_iter(runner, 1, self.data_batch, self.outputs, mode='test') diff --git a/tools/browse_dataset.py 
b/tools/browse_dataset.py index 4a62a8a49..b5e0dc978 100644 --- a/tools/browse_dataset.py +++ b/tools/browse_dataset.py @@ -1,48 +1,29 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse -import os -import warnings -from pathlib import Path +import os.path as osp import mmcv -import numpy as np from mmcv import Config, DictAction -from mmseg.registry import DATASETS +from mmseg.datasets import DATASETS +from mmseg.registry import VISUALIZERS +from mmseg.utils import register_all_modules def parse_args(): parser = argparse.ArgumentParser(description='Browse a dataset') parser.add_argument('config', help='train config file path') - parser.add_argument( - '--show-origin', - default=False, - action='store_true', - help='if True, omit all augmentation in pipeline,' - ' show origin image and seg map') - parser.add_argument( - '--skip-type', - type=str, - nargs='+', - default=['DefaultFormatBundle', 'Normalize', 'Collect'], - help='skip some useless pipeline,if `show-origin` is true, ' - 'all pipeline except `Load` will be skipped') parser.add_argument( '--output-dir', - default='./output', + default=None, type=str, help='If there is no display interface, you can save it') - parser.add_argument('--show', default=False, action='store_true') + parser.add_argument('--not-show', default=False, action='store_true') parser.add_argument( '--show-interval', - type=int, - default=999, - help='the interval of show (ms)') - parser.add_argument( - '--opacity', type=float, - default=0.5, - help='the opacity of semantic map') + default=2, + help='the interval of show (s)') parser.add_argument( '--cfg-options', nargs='+', @@ -57,124 +38,35 @@ def parse_args(): return args -def imshow_semantic(img, - seg, - class_names, - palette=None, - win_name='', - show=False, - wait_time=0, - out_file=None, - opacity=0.5): - """Draw `result` over `img`. - - Args: - img (str or Tensor): The image to be displayed. - seg (Tensor): The semantic segmentation results to draw over - `img`. 
- class_names (list[str]): Names of each classes. - palette (list[list[int]]] | np.ndarray | None): The palette of - segmentation map. If None is given, random palette will be - generated. Default: None - win_name (str): The window name. - wait_time (int): Value of waitKey param. - Default: 0. - show (bool): Whether to show the image. - Default: False. - out_file (str or None): The filename to write the image. - Default: None. - opacity(float): Opacity of painted segmentation map. - Default 0.5. - Must be in (0, 1] range. - Returns: - img (Tensor): Only if not `show` or `out_file` - """ - img = mmcv.imread(img) - img = img.copy() - if palette is None: - palette = np.random.randint(0, 255, size=(len(class_names), 3)) - palette = np.array(palette) - assert palette.shape[0] == len(class_names) - assert palette.shape[1] == 3 - assert len(palette.shape) == 2 - assert 0 < opacity <= 1.0 - color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) - for label, color in enumerate(palette): - color_seg[seg == label, :] = color - # convert to BGR - color_seg = color_seg[..., ::-1] - - img = img * (1 - opacity) + color_seg * opacity - img = img.astype(np.uint8) - # if out_file specified, do not show image in window - if out_file is not None: - show = False - - if show: - mmcv.imshow(img, win_name, wait_time) - if out_file is not None: - mmcv.imwrite(img, out_file) - - if not (show or out_file): - warnings.warn('show==False and out_file is not specified, only ' - 'result image will be returned') - return img - - -def _retrieve_data_cfg(_data_cfg, skip_type, show_origin): - if show_origin is True: - # only keep pipeline of Loading data and ann - _data_cfg['pipeline'] = [ - x for x in _data_cfg.pipeline if 'Load' in x['type'] - ] - else: - _data_cfg['pipeline'] = [ - x for x in _data_cfg.pipeline if x['type'] not in skip_type - ] - - -def retrieve_data_cfg(config_path, skip_type, cfg_options, show_origin=False): - cfg = Config.fromfile(config_path) - if cfg_options is 
not None: - cfg.merge_from_dict(cfg_options) - train_data_cfg = cfg.data.train - if isinstance(train_data_cfg, list): - for _data_cfg in train_data_cfg: - while 'dataset' in _data_cfg and _data_cfg[ - 'type'] != 'MultiImageMixDataset': - _data_cfg = _data_cfg['dataset'] - if 'pipeline' in _data_cfg: - _retrieve_data_cfg(_data_cfg, skip_type, show_origin) - else: - raise ValueError - else: - while 'dataset' in train_data_cfg and train_data_cfg[ - 'type'] != 'MultiImageMixDataset': - train_data_cfg = train_data_cfg['dataset'] - _retrieve_data_cfg(train_data_cfg, skip_type, show_origin) - return cfg - - def main(): args = parse_args() - cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options, - args.show_origin) - dataset = DATASETS.build(cfg.data.train) + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # register all modules in mmseg into the registries + register_all_modules() + + dataset = DATASETS.build(cfg.train_dataloader.dataset) + + visualizer = VISUALIZERS.build(cfg.visualizer) + visualizer.dataset_meta = dataset.METAINFO + progress_bar = mmcv.ProgressBar(len(dataset)) for item in dataset: - filename = os.path.join(args.output_dir, - Path(item['filename']).name - ) if args.output_dir is not None else None - imshow_semantic( - item['img'], - item['gt_semantic_seg'], - dataset.CLASSES, - dataset.PALETTE, - show=args.show, - wait_time=args.show_interval, - out_file=filename, - opacity=args.opacity, - ) + img = item['inputs'].permute(1, 2, 0).numpy() + data_sample = item['data_sample'].numpy() + img_path = osp.basename(item['data_sample'].img_path) + + img = img[..., [2, 1, 0]] # bgr to rgb + + visualizer.add_datasample( + osp.basename(img_path), + img, + data_sample, + show=not args.not_show, + wait_time=args.show_interval) + progress_bar.update() diff --git a/tools/test.py b/tools/test.py index e4e1b5d4d..59f7c7095 100644 --- a/tools/test.py +++ b/tools/test.py @@ -19,6 +19,15 @@ 
def parse_args(): '--work-dir', help=('if specified, the evaluation metric results will be dumped' 'into the directory as json')) + parser.add_argument( + '--show', action='store_true', help='show prediction results') + parser.add_argument( + '--show-dir', + help='directory where painted images will be saved. ' + 'If specified, it will be automatically saved ' + 'to the work_dir/timestamp/show_dir') + parser.add_argument( + '--wait-time', type=float, default=2, help='the interval of show (s)') parser.add_argument( '--cfg-options', nargs='+', @@ -42,6 +51,41 @@ def parse_args(): return args +def trigger_visualization_hook(cfg, args): + """Enable the visualization hook in ``cfg`` from command-line options. + + Args: + cfg: Config providing ``default_hooks`` and ``visualizer``. + args: Parsed arguments providing ``show``, ``show_dir`` and + ``wait_time``. + + Returns: + The same ``cfg`` object, modified in place. + + Raises: + RuntimeError: If ``default_hooks`` has no ``visualization`` entry. + """ + default_hooks = cfg.default_hooks + if 'visualization' in default_hooks: + visualization_hook = default_hooks['visualization'] + # Turn on visualization + visualization_hook['draw'] = True + if args.show: + visualization_hook['show'] = True + visualization_hook['wait_time'] = args.wait_time + if args.show_dir: + # SegVisualizationHook takes no output-directory argument; + # drawn images are stored through the visualizer's ``save_dir``. + cfg.visualizer['save_dir'] = args.show_dir + else: + raise RuntimeError( + 'SegVisualizationHook must be included in default_hooks. ' + 'refer to usage ' + '"visualization=dict(type=\'SegVisualizationHook\')"') + + return cfg + + def main(): args = parse_args() @@ -66,6 +110,9 @@ def main(): cfg.load_from = args.checkpoint + if args.show or args.show_dir: + cfg = trigger_visualization_hook(cfg, args) + # build the runner from config runner = Runner.from_cfg(cfg)