[Feature] Add SegVisualizer (#1792)

* [Feature] Add SegVisualizer

* change name to visualizer_example

* fix typo

* refactor folder structure
This commit is contained in:
MengzhangLI 2022-07-27 16:28:09 +08:00 committed by GitHub
parent 4079d6dfed
commit 6873f9ece8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 599 additions and 142 deletions

View File

@ -1,7 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hooks import SegVisualizationHook
from .optimizers import (LayerDecayOptimizerConstructor,
LearningRateDecayOptimizerConstructor)
from .visualization import SegLocalVisualizer
__all__ = [
'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor'
'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
'SegVisualizationHook', 'SegLocalVisualizer'
]

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .visualization_hook import SegVisualizationHook
__all__ = ['SegVisualizationHook']

View File

@ -0,0 +1,101 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Sequence
import mmcv
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmseg.data import SegDataSample
from mmseg.engine.visualization import SegLocalVisualizer
from mmseg.registry import HOOKS
@HOOKS.register_module()
class SegVisualizationHook(Hook):
"""Segmentation Visualization Hook. Used to visualize validation and
testing process prediction results.
In the testing phase:
1. If ``show`` is True, it means that only the prediction results are
visualized without storing data, so ``vis_backends`` needs to
be excluded.
Args:
draw (bool): whether to draw prediction results. If it is False,
it means that no drawing will be done. Defaults to False.
interval (int): The interval of visualization. Defaults to 50.
show (bool): Whether to display the drawn image. Default to False.
wait_time (float): The interval of show (s). Defaults to 0.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details.
Defaults to ``dict(backend='disk')``.
"""
def __init__(self,
draw: bool = False,
interval: int = 50,
show: bool = False,
wait_time: float = 0.,
file_client_args: dict = dict(backend='disk')):
self._visualizer: SegLocalVisualizer = \
SegLocalVisualizer.get_current_instance()
self.interval = interval
self.show = show
if self.show:
# No need to think about vis backends.
self._visualizer._vis_backends = {}
warnings.warn('The show is True, it means that only '
'the prediction results are visualized '
'without storing data, so vis_backends '
'needs to be excluded.')
self.wait_time = wait_time
self.file_client_args = file_client_args.copy()
self.file_client = None
self.draw = draw
if not self.draw:
warnings.warn('The draw is False, it means that the '
'hook for visualization will not take '
'effect. The results will NOT be '
'visualized or stored.')
def after_iter(self,
runner: Runner,
batch_idx: int,
data_batch: Sequence[dict],
outputs: Sequence[SegDataSample],
mode: str = 'val') -> None:
"""Run after every ``self.interval`` validation iterations.
Args:
runner (:obj:`Runner`): The runner of the validation process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (Sequence[dict]): Data from dataloader.
outputs (Sequence[:obj:`SegDataSample`]): Outputs from model.
mode (str): mode (str): Current mode of runner. Defaults to 'val'.
"""
if self.draw is False or mode == 'train':
return
if self.file_client is None:
self.file_client = mmcv.FileClient(**self.file_client_args)
if self.every_n_inner_iters(batch_idx, self.interval):
for input_data, output in zip(data_batch, outputs):
img_path = input_data['data_sample'].img_path
img_bytes = self.file_client.get(img_path)
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
window_name = f'{mode}_{osp.basename(img_path)}'
gt_sample = input_data['data_sample']
self._visualizer.add_datasample(
window_name,
img,
gt_sample=gt_sample,
pred_sample=output,
show=self.show,
wait_time=self.wait_time,
step=runner.iter)

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .local_visualizer import SegLocalVisualizer
__all__ = ['SegLocalVisualizer']

View File

@ -0,0 +1,172 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple
import numpy as np
from mmengine import Visualizer
from mmengine.data import PixelData
from mmengine.dist import master_only
from mmseg.data import SegDataSample
from mmseg.registry import VISUALIZERS
@VISUALIZERS.register_module()
class SegLocalVisualizer(Visualizer):
"""MMSegmentation Local Visualizer.
Args:
name (str): Name of the instance. Defaults to 'visualizer'.
image (np.ndarray, optional): the origin image to draw. The format
should be RGB. Defaults to None.
vis_backends (list, optional): Visual backend config list.
Defaults to None.
save_dir (str, optional): Save file dir for all storage backends.
If it is None, the backend storage will not save any data.
alpha (int, float): The transparency of segmentation mask.
Defaults to 0.8.
Examples:
>>> import numpy as np
>>> import torch
>>> from mmengine.data import PixelData
>>> from mmseg.data import SegDataSample
>>> from mmseg.engine.visualization import SegLocalVisualizer
>>> seg_local_visualizer = SegLocalVisualizer()
>>> image = np.random.randint(0, 256,
... size=(10, 12, 3)).astype('uint8')
>>> gt_sem_seg_data = dict(data=torch.randint(0, 2, (1, 10, 12)))
>>> gt_sem_seg = PixelData(**gt_sem_seg_data)
>>> gt_seg_data_sample = SegDataSample()
>>> gt_seg_data_sample.gt_sem_seg = gt_sem_seg
>>> seg_local_visualizer.dataset_meta = dict(
>>> classes=('background', 'foreground'),
>>> palette=[[120, 120, 120], [6, 230, 230]])
>>> seg_local_visualizer.add_datasample('visualizer_example',
... image, gt_seg_data_sample)
>>> seg_local_visualizer.add_datasample(
... 'visualizer_example', image,
... gt_seg_data_sample, show=True)
"""
def __init__(self,
name: str = 'visualizer',
image: Optional[np.ndarray] = None,
vis_backends: Optional[Dict] = None,
save_dir: Optional[str] = None,
alpha: float = 0.8,
**kwargs):
super().__init__(name, image, vis_backends, save_dir, **kwargs)
self.alpha = alpha
# Set default value. When calling
# `SegLocalVisualizer().dataset_meta=xxx`,
# it will override the default value.
self.dataset_meta = {}
def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
classes: Optional[Tuple[str]],
palette: Optional[List[List[int]]]) -> np.ndarray:
"""Draw semantic seg of GT or prediction.
Args:
image (np.ndarray): The image to draw.
sem_seg (:obj:`PixelData`): Data structure for
pixel-level annotations or predictions.
classes (Tuple[str], optional): Category information.
palette (List[List[int]], optional): The palette of
segmentation map.
Returns:
np.ndarray: the drawn image which channel is RGB.
"""
num_classes = len(classes)
sem_seg = sem_seg.data
ids = np.unique(sem_seg)[::-1]
legal_indices = ids < num_classes
ids = ids[legal_indices]
labels = np.array(ids, dtype=np.int64)
colors = [palette[label] for label in labels]
self.set_image(image)
# draw semantic masks
for label, color in zip(labels, colors):
self.draw_binary_masks(
sem_seg == label, colors=[color], alphas=self.alpha)
return self.get_image()
@master_only
def add_datasample(self,
name: str,
image: np.ndarray,
gt_sample: Optional[SegDataSample] = None,
pred_sample: Optional[SegDataSample] = None,
draw_gt: bool = True,
draw_pred: bool = True,
show: bool = False,
wait_time: float = 0,
step: int = 0) -> None:
"""Draw datasample and save to all backends.
- If GT and prediction are plotted at the same time, they are
displayed in a stitched image where the left image is the
ground truth and the right image is the prediction.
- If ``show`` is True, all storage backends are ignored, and
the images will be displayed in a local window.
Args:
name (str): The image identifier.
image (np.ndarray): The image to draw.
gt_sample (:obj:`SegDataSample`, optional): GT SegDataSample.
Defaults to None.
pred_sample (:obj:`SegDataSample`, optional): Prediction
SegDataSample. Defaults to None.
draw_gt (bool): Whether to draw GT SegDataSample. Default to True.
draw_pred (bool): Whether to draw Prediction SegDataSample.
Defaults to True.
show (bool): Whether to display the drawn image. Default to False.
wait_time (float): The interval of show (s). Defaults to 0.
step (int): Global step value to record. Defaults to 0.
"""
classes = self.dataset_meta.get('classes', None)
palette = self.dataset_meta.get('palette', None)
gt_img_data = None
pred_img_data = None
if draw_gt and gt_sample is not None:
gt_img_data = image
if 'gt_sem_seg' in gt_sample:
assert classes is not None, 'class information is ' \
'not provided when ' \
'visualizing semantic ' \
'segmentation results.'
gt_img_data = self._draw_sem_seg(gt_img_data,
gt_sample.gt_sem_seg, classes,
palette)
if draw_pred and pred_sample is not None:
pred_img_data = image
if 'pred_sem_seg' in pred_sample:
assert classes is not None, 'class information is ' \
'not provided when ' \
'visualizing semantic ' \
'segmentation results.'
pred_img_data = self._draw_sem_seg(pred_img_data,
pred_sample.pred_sem_seg,
classes, palette)
if gt_img_data is not None and pred_img_data is not None:
drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
elif gt_img_data is not None:
drawn_img = gt_img_data
else:
drawn_img = pred_img_data
if show:
self.show(drawn_img, win_name=name, wait_time=wait_time)
else:
self.add_image(name, drawn_img, step)

View File

@ -0,0 +1,187 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
from unittest import TestCase
import cv2
import mmcv
import numpy as np
import torch
from mmengine.data import PixelData
from mmseg.data import SegDataSample
from mmseg.engine.visualization import SegLocalVisualizer
class TestSegLocalVisualizer(TestCase):
def test_add_datasample(self):
h = 10
w = 12
num_class = 2
out_file = 'out_file'
image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')
# test gt_sem_seg
gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
gt_sem_seg = PixelData(**gt_sem_seg_data)
gt_seg_data_sample = SegDataSample()
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
seg_local_visualizer = SegLocalVisualizer(
vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir')
seg_local_visualizer.dataset_meta = dict(
classes=('background', 'foreground'),
palette=[[120, 120, 120], [6, 230, 230]])
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
# test out_file
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
assert os.path.exists(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
drawn_img = cv2.imread(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
assert drawn_img.shape == (h, w, 3)
os.remove(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
os.rmdir('temp_dir' + '/vis_data/vis_image')
# test gt_instances and pred_instances
pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
pred_sem_seg = PixelData(**pred_sem_seg_data)
pred_seg_data_sample = SegDataSample()
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample,
pred_seg_data_sample)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w * 2, 3))
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_gt=False)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w, 3))
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_pred=False)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w, 3))
os.rmdir('temp_dir/vis_data')
os.rmdir('temp_dir')
def test_cityscapes_add_datasample(self):
h = 128
w = 256
num_class = 19
out_file = 'out_file_cityscapes'
image = mmcv.imread(
osp.join(
osp.dirname(__file__),
'../data/pseudo_cityscapes_dataset/leftImg8bit/frankfurt_000000_000294_leftImg8bit.png' # noqa
),
'color')
sem_seg = mmcv.imread(
osp.join(
osp.dirname(__file__),
'../data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_labelTrainIds.png' # noqa
),
'unchanged')
sem_seg = torch.unsqueeze(torch.from_numpy(sem_seg), 0)
gt_sem_seg_data = dict(data=sem_seg)
gt_sem_seg = PixelData(**gt_sem_seg_data)
gt_seg_data_sample = SegDataSample()
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
seg_local_visualizer = SegLocalVisualizer(
vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir')
seg_local_visualizer.dataset_meta = dict(
classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
'traffic light', 'traffic sign', 'vegetation', 'terrain',
'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
'motorcycle', 'bicycle'),
palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70],
[102, 102, 156], [190, 153, 153], [153, 153, 153],
[250, 170, 30], [220, 220, 0], [107, 142, 35],
[152, 251, 152], [70, 130, 180], [220, 20, 60],
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
[0, 80, 100], [0, 0, 230], [119, 11, 32]])
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
# test out_file
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
assert os.path.exists(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
drawn_img = cv2.imread(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
assert drawn_img.shape == (h, w, 3)
os.remove(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
os.rmdir('temp_dir/vis_data/vis_image')
# test gt_instances and pred_instances
pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
pred_sem_seg = PixelData(**pred_sem_seg_data)
pred_seg_data_sample = SegDataSample()
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample,
pred_seg_data_sample)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w * 2, 3))
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_gt=False)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w, 3))
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_pred=False)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w, 3))
os.rmdir('temp_dir/vis_data')
os.rmdir('temp_dir')
def _assert_image_and_shape(self, out_file, out_shape):
assert os.path.exists(out_file)
drawn_img = cv2.imread(out_file)
assert drawn_img.shape == out_shape
os.remove(out_file)
os.rmdir('temp_dir/vis_data/vis_image')

View File

@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock
import torch
from mmengine.data import PixelData
from mmseg.data import SegDataSample
from mmseg.engine.hooks import SegVisualizationHook
from mmseg.engine.visualization import SegLocalVisualizer
class TestVisualizationHook(TestCase):
def setUp(self) -> None:
h = 288
w = 512
num_class = 2
SegLocalVisualizer.get_instance('visualizer')
SegLocalVisualizer.dataset_meta = dict(
classes=('background', 'foreground'),
palette=[[120, 120, 120], [6, 230, 230]])
data_sample = SegDataSample()
data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
self.data_batch = [{'data_sample': data_sample}] * 2
pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
pred_sem_seg = PixelData(**pred_sem_seg_data)
pred_seg_data_sample = SegDataSample()
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
self.outputs = [pred_seg_data_sample] * 2
def test_after_iter(self):
runner = Mock()
runner.iter = 1
hook = SegVisualizationHook(draw=True, interval=1)
hook._after_iter(
runner, 1, self.data_batch, self.outputs, mode='train')
hook._after_iter(runner, 1, self.data_batch, self.outputs, mode='val')
hook._after_iter(runner, 1, self.data_batch, self.outputs, mode='test')
def test_after_val_iter(self):
runner = Mock()
runner.iter = 2
hook = SegVisualizationHook(interval=1)
hook.after_val_iter(runner, 1, self.data_batch, self.outputs)
hook = SegVisualizationHook(draw=True, interval=1)
hook.after_val_iter(runner, 1, self.data_batch, self.outputs)
hook = SegVisualizationHook(
draw=True, interval=1, show=True, wait_time=1)
hook.after_val_iter(runner, 1, self.data_batch, self.outputs)
def test_after_test_iter(self):
runner = Mock()
runner.iter = 3
hook = SegVisualizationHook(draw=True, interval=1)
hook.after_iter(runner, 1, self.data_batch, self.outputs)

View File

@ -1,48 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import warnings
from pathlib import Path
import os.path as osp
import mmcv
import numpy as np
from mmcv import Config, DictAction
from mmseg.registry import DATASETS
from mmseg.datasets import DATASETS
from mmseg.registry import VISUALIZERS
from mmseg.utils import register_all_modules
def parse_args():
parser = argparse.ArgumentParser(description='Browse a dataset')
parser.add_argument('config', help='train config file path')
parser.add_argument(
'--show-origin',
default=False,
action='store_true',
help='if True, omit all augmentation in pipeline,'
' show origin image and seg map')
parser.add_argument(
'--skip-type',
type=str,
nargs='+',
default=['DefaultFormatBundle', 'Normalize', 'Collect'],
help='skip some useless pipelineif `show-origin` is true, '
'all pipeline except `Load` will be skipped')
parser.add_argument(
'--output-dir',
default='./output',
default=None,
type=str,
help='If there is no display interface, you can save it')
parser.add_argument('--show', default=False, action='store_true')
parser.add_argument('--not-show', default=False, action='store_true')
parser.add_argument(
'--show-interval',
type=int,
default=999,
help='the interval of show (ms)')
parser.add_argument(
'--opacity',
type=float,
default=0.5,
help='the opacity of semantic map')
default=2,
help='the interval of show (s)')
parser.add_argument(
'--cfg-options',
nargs='+',
@ -57,124 +38,35 @@ def parse_args():
return args
def imshow_semantic(img,
seg,
class_names,
palette=None,
win_name='',
show=False,
wait_time=0,
out_file=None,
opacity=0.5):
"""Draw `result` over `img`.
Args:
img (str or Tensor): The image to be displayed.
seg (Tensor): The semantic segmentation results to draw over
`img`.
class_names (list[str]): Names of each classes.
palette (list[list[int]]] | np.ndarray | None): The palette of
segmentation map. If None is given, random palette will be
generated. Default: None
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
show (bool): Whether to show the image.
Default: False.
out_file (str or None): The filename to write the image.
Default: None.
opacity(float): Opacity of painted segmentation map.
Default 0.5.
Must be in (0, 1] range.
Returns:
img (Tensor): Only if not `show` or `out_file`
"""
img = mmcv.imread(img)
img = img.copy()
if palette is None:
palette = np.random.randint(0, 255, size=(len(class_names), 3))
palette = np.array(palette)
assert palette.shape[0] == len(class_names)
assert palette.shape[1] == 3
assert len(palette.shape) == 2
assert 0 < opacity <= 1.0
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
for label, color in enumerate(palette):
color_seg[seg == label, :] = color
# convert to BGR
color_seg = color_seg[..., ::-1]
img = img * (1 - opacity) + color_seg * opacity
img = img.astype(np.uint8)
# if out_file specified, do not show image in window
if out_file is not None:
show = False
if show:
mmcv.imshow(img, win_name, wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)
if not (show or out_file):
warnings.warn('show==False and out_file is not specified, only '
'result image will be returned')
return img
def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
if show_origin is True:
# only keep pipeline of Loading data and ann
_data_cfg['pipeline'] = [
x for x in _data_cfg.pipeline if 'Load' in x['type']
]
else:
_data_cfg['pipeline'] = [
x for x in _data_cfg.pipeline if x['type'] not in skip_type
]
def retrieve_data_cfg(config_path, skip_type, cfg_options, show_origin=False):
cfg = Config.fromfile(config_path)
if cfg_options is not None:
cfg.merge_from_dict(cfg_options)
train_data_cfg = cfg.data.train
if isinstance(train_data_cfg, list):
for _data_cfg in train_data_cfg:
while 'dataset' in _data_cfg and _data_cfg[
'type'] != 'MultiImageMixDataset':
_data_cfg = _data_cfg['dataset']
if 'pipeline' in _data_cfg:
_retrieve_data_cfg(_data_cfg, skip_type, show_origin)
else:
raise ValueError
else:
while 'dataset' in train_data_cfg and train_data_cfg[
'type'] != 'MultiImageMixDataset':
train_data_cfg = train_data_cfg['dataset']
_retrieve_data_cfg(train_data_cfg, skip_type, show_origin)
return cfg
def main():
args = parse_args()
cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options,
args.show_origin)
dataset = DATASETS.build(cfg.data.train)
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# register all modules in mmseg into the registries
register_all_modules()
dataset = DATASETS.build(cfg.train_dataloader.dataset)
visualizer = VISUALIZERS.build(cfg.visualizer)
visualizer.dataset_meta = dataset.METAINFO
progress_bar = mmcv.ProgressBar(len(dataset))
for item in dataset:
filename = os.path.join(args.output_dir,
Path(item['filename']).name
) if args.output_dir is not None else None
imshow_semantic(
item['img'],
item['gt_semantic_seg'],
dataset.CLASSES,
dataset.PALETTE,
show=args.show,
wait_time=args.show_interval,
out_file=filename,
opacity=args.opacity,
)
img = item['inputs'].permute(1, 2, 0).numpy()
data_sample = item['data_sample'].numpy()
img_path = osp.basename(item['data_sample'].img_path)
img = img[..., [2, 1, 0]] # bgr to rgb
visualizer.add_datasample(
osp.basename(img_path),
img,
data_sample,
show=not args.not_show,
wait_time=args.show_interval)
progress_bar.update()

View File

@ -19,6 +19,15 @@ def parse_args():
'--work-dir',
help=('if specified, the evaluation metric results will be dumped'
'into the directory as json'))
parser.add_argument(
'--show', action='store_true', help='show prediction results')
parser.add_argument(
'--show-dir',
help='directory where painted images will be saved. '
'If specified, it will be automatically saved '
'to the work_dir/timestamp/show_dir')
parser.add_argument(
'--wait-time', type=float, default=2, help='the interval of show (s)')
parser.add_argument(
'--cfg-options',
nargs='+',
@ -42,6 +51,26 @@ def parse_args():
return args
def trigger_visualization_hook(cfg, args):
default_hooks = cfg.default_hooks
if 'visualization' in default_hooks:
visualization_hook = default_hooks['visualization']
# Turn on visualization
visualization_hook['draw'] = True
if args.show:
visualization_hook['show'] = True
visualization_hook['wait_time'] = args.wait_time
if args.show_dir:
visualization_hook['test_out_dir'] = args.show_dir
else:
raise RuntimeError(
'VisualizationHook must be included in default_hooks.'
'refer to usage '
'"visualization=dict(type=\'VisualizationHook\')"')
return cfg
def main():
args = parse_args()
@ -66,6 +95,9 @@ def main():
cfg.load_from = args.checkpoint
if args.show or args.show_dir:
cfg = trigger_visualization_hook(cfg, args)
# build the runner from config
runner = Runner.from_cfg(cfg)