mirror of https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] Add SegVisualizer (#1792)
* [Feature] Add SegVisualizer
* change name to visualizer_example
* fix typo
* refactor folder structure
This commit is contained in:
parent 4079d6dfed
commit 6873f9ece8
@@ -1,7 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hooks import SegVisualizationHook
from .optimizers import (LayerDecayOptimizerConstructor,
                         LearningRateDecayOptimizerConstructor)
from .visualization import SegLocalVisualizer

__all__ = [
    'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor'
    'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
    'SegVisualizationHook', 'SegLocalVisualizer'
]
4 mmseg/engine/hooks/__init__.py Normal file
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .visualization_hook import SegVisualizationHook

__all__ = ['SegVisualizationHook']
101 mmseg/engine/hooks/visualization_hook.py Normal file
@@ -0,0 +1,101 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Sequence

import mmcv
from mmengine.hooks import Hook
from mmengine.runner import Runner

from mmseg.data import SegDataSample
from mmseg.engine.visualization import SegLocalVisualizer
from mmseg.registry import HOOKS


@HOOKS.register_module()
class SegVisualizationHook(Hook):
    """Segmentation Visualization Hook. Used to visualize validation and
    testing process prediction results.

    In the testing phase:

    1. If ``show`` is True, it means that only the prediction results are
       visualized without storing data, so ``vis_backends`` needs to
       be excluded.

    Args:
        draw (bool): Whether to draw prediction results. If it is False,
            it means that no drawing will be done. Defaults to False.
        interval (int): The interval of visualization. Defaults to 50.
        show (bool): Whether to display the drawn image. Defaults to False.
        wait_time (float): The interval of show (s). Defaults to 0.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
    """

    def __init__(self,
                 draw: bool = False,
                 interval: int = 50,
                 show: bool = False,
                 wait_time: float = 0.,
                 file_client_args: dict = dict(backend='disk')):
        self._visualizer: SegLocalVisualizer = \
            SegLocalVisualizer.get_current_instance()
        self.interval = interval
        self.show = show
        if self.show:
            # No need to think about vis backends.
            self._visualizer._vis_backends = {}
            warnings.warn('The show is True, it means that only '
                          'the prediction results are visualized '
                          'without storing data, so vis_backends '
                          'needs to be excluded.')

        self.wait_time = wait_time
        self.file_client_args = file_client_args.copy()
        self.file_client = None
        self.draw = draw
        if not self.draw:
            warnings.warn('The draw is False, it means that the '
                          'hook for visualization will not take '
                          'effect. The results will NOT be '
                          'visualized or stored.')

    def after_iter(self,
                   runner: Runner,
                   batch_idx: int,
                   data_batch: Sequence[dict],
                   outputs: Sequence[SegDataSample],
                   mode: str = 'val') -> None:
        """Run after every ``self.interval`` validation iterations.

        Args:
            runner (:obj:`Runner`): The runner of the validation process.
            batch_idx (int): The index of the current batch in the val loop.
            data_batch (Sequence[dict]): Data from dataloader.
            outputs (Sequence[:obj:`SegDataSample`]): Outputs from model.
            mode (str): Current mode of runner. Defaults to 'val'.
        """
        if self.draw is False or mode == 'train':
            return

        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        if self.every_n_inner_iters(batch_idx, self.interval):
            for input_data, output in zip(data_batch, outputs):
                img_path = input_data['data_sample'].img_path
                img_bytes = self.file_client.get(img_path)
                img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
                window_name = f'{mode}_{osp.basename(img_path)}'

                gt_sample = input_data['data_sample']
                self._visualizer.add_datasample(
                    window_name,
                    img,
                    gt_sample=gt_sample,
                    pred_sample=output,
                    show=self.show,
                    wait_time=self.wait_time,
                    step=runner.iter)
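The hook takes no arguments from the runner directly; it is switched on from the config. A minimal sketch of such a config fragment follows; the key names (`default_hooks`, `visualization`, `visualizer`) follow the MMEngine convention that the `trigger_visualization_hook` change further down relies on, and the concrete values are illustrative only, not part of this commit:

# Hypothetical config fragment, not part of this diff: enable the new hook
# and register SegLocalVisualizer with a local backend (as the tests do).
default_hooks = dict(
    visualization=dict(type='SegVisualizationHook', draw=True, interval=50))

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')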
4 mmseg/engine/visualization/__init__.py Normal file
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .local_visualizer import SegLocalVisualizer

__all__ = ['SegLocalVisualizer']
172 mmseg/engine/visualization/local_visualizer.py Normal file
@@ -0,0 +1,172 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple

import numpy as np
from mmengine import Visualizer
from mmengine.data import PixelData
from mmengine.dist import master_only

from mmseg.data import SegDataSample
from mmseg.registry import VISUALIZERS


@VISUALIZERS.register_module()
class SegLocalVisualizer(Visualizer):
    """MMSegmentation Local Visualizer.

    Args:
        name (str): Name of the instance. Defaults to 'visualizer'.
        image (np.ndarray, optional): The origin image to draw. The format
            should be RGB. Defaults to None.
        vis_backends (list, optional): Visual backend config list.
            Defaults to None.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        alpha (int, float): The transparency of segmentation mask.
            Defaults to 0.8.

    Examples:
        >>> import numpy as np
        >>> import torch
        >>> from mmengine.data import PixelData
        >>> from mmseg.data import SegDataSample
        >>> from mmseg.engine.visualization import SegLocalVisualizer

        >>> seg_local_visualizer = SegLocalVisualizer()
        >>> image = np.random.randint(0, 256,
        ...                           size=(10, 12, 3)).astype('uint8')
        >>> gt_sem_seg_data = dict(data=torch.randint(0, 2, (1, 10, 12)))
        >>> gt_sem_seg = PixelData(**gt_sem_seg_data)
        >>> gt_seg_data_sample = SegDataSample()
        >>> gt_seg_data_sample.gt_sem_seg = gt_sem_seg
        >>> seg_local_visualizer.dataset_meta = dict(
        >>>     classes=('background', 'foreground'),
        >>>     palette=[[120, 120, 120], [6, 230, 230]])
        >>> seg_local_visualizer.add_datasample('visualizer_example',
        ...                                     image, gt_seg_data_sample)
        >>> seg_local_visualizer.add_datasample(
        ...     'visualizer_example', image,
        ...     gt_seg_data_sample, show=True)
    """

    def __init__(self,
                 name: str = 'visualizer',
                 image: Optional[np.ndarray] = None,
                 vis_backends: Optional[Dict] = None,
                 save_dir: Optional[str] = None,
                 alpha: float = 0.8,
                 **kwargs):
        super().__init__(name, image, vis_backends, save_dir, **kwargs)
        self.alpha = alpha
        # Set default value. When calling
        # `SegLocalVisualizer().dataset_meta=xxx`,
        # it will override the default value.
        self.dataset_meta = {}

    def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
                      classes: Optional[Tuple[str]],
                      palette: Optional[List[List[int]]]) -> np.ndarray:
        """Draw semantic seg of GT or prediction.

        Args:
            image (np.ndarray): The image to draw.
            sem_seg (:obj:`PixelData`): Data structure for
                pixel-level annotations or predictions.
            classes (Tuple[str], optional): Category information.
            palette (List[List[int]], optional): The palette of
                segmentation map.

        Returns:
            np.ndarray: The drawn image whose channel order is RGB.
        """
        num_classes = len(classes)

        sem_seg = sem_seg.data
        ids = np.unique(sem_seg)[::-1]
        legal_indices = ids < num_classes
        ids = ids[legal_indices]
        labels = np.array(ids, dtype=np.int64)

        colors = [palette[label] for label in labels]

        self.set_image(image)

        # draw semantic masks
        for label, color in zip(labels, colors):
            self.draw_binary_masks(
                sem_seg == label, colors=[color], alphas=self.alpha)

        return self.get_image()

    @master_only
    def add_datasample(self,
                       name: str,
                       image: np.ndarray,
                       gt_sample: Optional[SegDataSample] = None,
                       pred_sample: Optional[SegDataSample] = None,
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       show: bool = False,
                       wait_time: float = 0,
                       step: int = 0) -> None:
        """Draw datasample and save to all backends.

        - If GT and prediction are plotted at the same time, they are
          displayed in a stitched image where the left image is the
          ground truth and the right image is the prediction.
        - If ``show`` is True, all storage backends are ignored, and
          the images will be displayed in a local window.

        Args:
            name (str): The image identifier.
            image (np.ndarray): The image to draw.
            gt_sample (:obj:`SegDataSample`, optional): GT SegDataSample.
                Defaults to None.
            pred_sample (:obj:`SegDataSample`, optional): Prediction
                SegDataSample. Defaults to None.
            draw_gt (bool): Whether to draw GT SegDataSample. Defaults to True.
            draw_pred (bool): Whether to draw Prediction SegDataSample.
                Defaults to True.
            show (bool): Whether to display the drawn image. Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            step (int): Global step value to record. Defaults to 0.
        """
        classes = self.dataset_meta.get('classes', None)
        palette = self.dataset_meta.get('palette', None)

        gt_img_data = None
        pred_img_data = None

        if draw_gt and gt_sample is not None:
            gt_img_data = image
            if 'gt_sem_seg' in gt_sample:
                assert classes is not None, 'class information is ' \
                                            'not provided when ' \
                                            'visualizing semantic ' \
                                            'segmentation results.'
                gt_img_data = self._draw_sem_seg(gt_img_data,
                                                 gt_sample.gt_sem_seg, classes,
                                                 palette)

        if draw_pred and pred_sample is not None:
            pred_img_data = image
            if 'pred_sem_seg' in pred_sample:
                assert classes is not None, 'class information is ' \
                                            'not provided when ' \
                                            'visualizing semantic ' \
                                            'segmentation results.'
                pred_img_data = self._draw_sem_seg(pred_img_data,
                                                   pred_sample.pred_sem_seg,
                                                   classes, palette)

        if gt_img_data is not None and pred_img_data is not None:
            drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
        elif gt_img_data is not None:
            drawn_img = gt_img_data
        else:
            drawn_img = pred_img_data

        if show:
            self.show(drawn_img, win_name=name, wait_time=wait_time)
        else:
            self.add_image(name, drawn_img, step)
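For orientation, this is how the two new classes connect at runtime: a SegLocalVisualizer is registered as the global visualizer instance, and `SegVisualizationHook.__init__` then fetches it through `SegLocalVisualizer.get_current_instance()`. A minimal wiring sketch, mirroring the test setup further down (all names and values are taken from those tests, nothing here is new API):

# Minimal wiring sketch based on the tests in this PR.
from mmseg.engine.hooks import SegVisualizationHook
from mmseg.engine.visualization import SegLocalVisualizer

SegLocalVisualizer.get_instance('visualizer')  # register the global instance
SegLocalVisualizer.dataset_meta = dict(
    classes=('background', 'foreground'),
    palette=[[120, 120, 120], [6, 230, 230]])
hook = SegVisualizationHook(draw=True, interval=1)  # looks up that instance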
187 tests/test_engine/test_local_visualizer.py Normal file
@@ -0,0 +1,187 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
from unittest import TestCase

import cv2
import mmcv
import numpy as np
import torch
from mmengine.data import PixelData

from mmseg.data import SegDataSample
from mmseg.engine.visualization import SegLocalVisualizer


class TestSegLocalVisualizer(TestCase):

    def test_add_datasample(self):
        h = 10
        w = 12
        num_class = 2
        out_file = 'out_file'

        image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')

        # test gt_sem_seg
        gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
        gt_sem_seg = PixelData(**gt_sem_seg_data)

        gt_seg_data_sample = SegDataSample()
        gt_seg_data_sample.gt_sem_seg = gt_sem_seg

        seg_local_visualizer = SegLocalVisualizer(
            vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir')
        seg_local_visualizer.dataset_meta = dict(
            classes=('background', 'foreground'),
            palette=[[120, 120, 120], [6, 230, 230]])
        seg_local_visualizer.add_datasample(out_file, image,
                                            gt_seg_data_sample)

        # test out_file
        seg_local_visualizer.add_datasample(out_file, image,
                                            gt_seg_data_sample)

        assert os.path.exists(
            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
        drawn_img = cv2.imread(
            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
        assert drawn_img.shape == (h, w, 3)

        os.remove(
            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
        os.rmdir('temp_dir' + '/vis_data/vis_image')

        # test gt_instances and pred_instances
        pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
        pred_sem_seg = PixelData(**pred_sem_seg_data)

        pred_seg_data_sample = SegDataSample()
        pred_seg_data_sample.pred_sem_seg = pred_sem_seg

        seg_local_visualizer.add_datasample(out_file, image,
                                            gt_seg_data_sample,
                                            pred_seg_data_sample)
        self._assert_image_and_shape(
            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
            (h, w * 2, 3))

        seg_local_visualizer.add_datasample(
            out_file,
            image,
            gt_seg_data_sample,
            pred_seg_data_sample,
            draw_gt=False)
        self._assert_image_and_shape(
            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
            (h, w, 3))

        seg_local_visualizer.add_datasample(
            out_file,
            image,
            gt_seg_data_sample,
            pred_seg_data_sample,
            draw_pred=False)
        self._assert_image_and_shape(
            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
            (h, w, 3))
        os.rmdir('temp_dir/vis_data')
        os.rmdir('temp_dir')

    def test_cityscapes_add_datasample(self):
        h = 128
        w = 256
        num_class = 19
        out_file = 'out_file_cityscapes'

        image = mmcv.imread(
            osp.join(
                osp.dirname(__file__),
                '../data/pseudo_cityscapes_dataset/leftImg8bit/frankfurt_000000_000294_leftImg8bit.png'  # noqa
            ),
            'color')
        sem_seg = mmcv.imread(
            osp.join(
                osp.dirname(__file__),
                '../data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_labelTrainIds.png'  # noqa
            ),
            'unchanged')
        sem_seg = torch.unsqueeze(torch.from_numpy(sem_seg), 0)
        gt_sem_seg_data = dict(data=sem_seg)
        gt_sem_seg = PixelData(**gt_sem_seg_data)

        gt_seg_data_sample = SegDataSample()
        gt_seg_data_sample.gt_sem_seg = gt_sem_seg

        seg_local_visualizer = SegLocalVisualizer(
            vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir')
        seg_local_visualizer.dataset_meta = dict(
            classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
                     'traffic light', 'traffic sign', 'vegetation', 'terrain',
                     'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
                     'motorcycle', 'bicycle'),
            palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70],
                     [102, 102, 156], [190, 153, 153], [153, 153, 153],
                     [250, 170, 30], [220, 220, 0], [107, 142, 35],
                     [152, 251, 152], [70, 130, 180], [220, 20, 60],
                     [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
                     [0, 80, 100], [0, 0, 230], [119, 11, 32]])
        seg_local_visualizer.add_datasample(out_file, image,
                                            gt_seg_data_sample)

        # test out_file
        seg_local_visualizer.add_datasample(out_file, image,
                                            gt_seg_data_sample)
        assert os.path.exists(
            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
        drawn_img = cv2.imread(
            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
        assert drawn_img.shape == (h, w, 3)

        os.remove(
            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
        os.rmdir('temp_dir/vis_data/vis_image')

        # test gt_instances and pred_instances
        pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
        pred_sem_seg = PixelData(**pred_sem_seg_data)

        pred_seg_data_sample = SegDataSample()
        pred_seg_data_sample.pred_sem_seg = pred_sem_seg

        seg_local_visualizer.add_datasample(out_file, image,
                                            gt_seg_data_sample,
                                            pred_seg_data_sample)
        self._assert_image_and_shape(
            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
            (h, w * 2, 3))

        seg_local_visualizer.add_datasample(
            out_file,
            image,
            gt_seg_data_sample,
            pred_seg_data_sample,
            draw_gt=False)
        self._assert_image_and_shape(
            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
            (h, w, 3))

        seg_local_visualizer.add_datasample(
            out_file,
            image,
            gt_seg_data_sample,
            pred_seg_data_sample,
            draw_pred=False)
        self._assert_image_and_shape(
            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
            (h, w, 3))

        os.rmdir('temp_dir/vis_data')
        os.rmdir('temp_dir')

    def _assert_image_and_shape(self, out_file, out_shape):
        assert os.path.exists(out_file)
        drawn_img = cv2.imread(out_file)
        assert drawn_img.shape == out_shape
        os.remove(out_file)
        os.rmdir('temp_dir/vis_data/vis_image')
62 tests/test_engine/test_visualization_hook.py Normal file
@@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock

import torch
from mmengine.data import PixelData

from mmseg.data import SegDataSample
from mmseg.engine.hooks import SegVisualizationHook
from mmseg.engine.visualization import SegLocalVisualizer


class TestVisualizationHook(TestCase):

    def setUp(self) -> None:

        h = 288
        w = 512
        num_class = 2

        SegLocalVisualizer.get_instance('visualizer')
        SegLocalVisualizer.dataset_meta = dict(
            classes=('background', 'foreground'),
            palette=[[120, 120, 120], [6, 230, 230]])

        data_sample = SegDataSample()
        data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
        self.data_batch = [{'data_sample': data_sample}] * 2

        pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
        pred_sem_seg = PixelData(**pred_sem_seg_data)
        pred_seg_data_sample = SegDataSample()
        pred_seg_data_sample.pred_sem_seg = pred_sem_seg
        self.outputs = [pred_seg_data_sample] * 2

    def test_after_iter(self):
        runner = Mock()
        runner.iter = 1
        hook = SegVisualizationHook(draw=True, interval=1)
        hook._after_iter(
            runner, 1, self.data_batch, self.outputs, mode='train')
        hook._after_iter(runner, 1, self.data_batch, self.outputs, mode='val')
        hook._after_iter(runner, 1, self.data_batch, self.outputs, mode='test')

    def test_after_val_iter(self):
        runner = Mock()
        runner.iter = 2
        hook = SegVisualizationHook(interval=1)
        hook.after_val_iter(runner, 1, self.data_batch, self.outputs)

        hook = SegVisualizationHook(draw=True, interval=1)
        hook.after_val_iter(runner, 1, self.data_batch, self.outputs)

        hook = SegVisualizationHook(
            draw=True, interval=1, show=True, wait_time=1)
        hook.after_val_iter(runner, 1, self.data_batch, self.outputs)

    def test_after_test_iter(self):
        runner = Mock()
        runner.iter = 3
        hook = SegVisualizationHook(draw=True, interval=1)
        hook.after_iter(runner, 1, self.data_batch, self.outputs)
@@ -1,48 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import warnings
from pathlib import Path
import os.path as osp

import mmcv
import numpy as np
from mmcv import Config, DictAction

from mmseg.registry import DATASETS
from mmseg.datasets import DATASETS
from mmseg.registry import VISUALIZERS
from mmseg.utils import register_all_modules


def parse_args():
    parser = argparse.ArgumentParser(description='Browse a dataset')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--show-origin',
        default=False,
        action='store_true',
        help='if True, omit all augmentation in pipeline,'
        ' show origin image and seg map')
    parser.add_argument(
        '--skip-type',
        type=str,
        nargs='+',
        default=['DefaultFormatBundle', 'Normalize', 'Collect'],
        help='skip some useless pipeline,if `show-origin` is true, '
        'all pipeline except `Load` will be skipped')
    parser.add_argument(
        '--output-dir',
        default='./output',
        default=None,
        type=str,
        help='If there is no display interface, you can save it')
    parser.add_argument('--show', default=False, action='store_true')
    parser.add_argument('--not-show', default=False, action='store_true')
    parser.add_argument(
        '--show-interval',
        type=int,
        default=999,
        help='the interval of show (ms)')
    parser.add_argument(
        '--opacity',
        type=float,
        default=0.5,
        help='the opacity of semantic map')
        default=2,
        help='the interval of show (s)')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
@@ -57,124 +38,35 @@ def parse_args():
    return args


def imshow_semantic(img,
                    seg,
                    class_names,
                    palette=None,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None,
                    opacity=0.5):
    """Draw `result` over `img`.

    Args:
        img (str or Tensor): The image to be displayed.
        seg (Tensor): The semantic segmentation results to draw over
            `img`.
        class_names (list[str]): Names of each classes.
        palette (list[list[int]]] | np.ndarray | None): The palette of
            segmentation map. If None is given, random palette will be
            generated. Default: None
        win_name (str): The window name.
        wait_time (int): Value of waitKey param.
            Default: 0.
        show (bool): Whether to show the image.
            Default: False.
        out_file (str or None): The filename to write the image.
            Default: None.
        opacity(float): Opacity of painted segmentation map.
            Default 0.5.
            Must be in (0, 1] range.
    Returns:
        img (Tensor): Only if not `show` or `out_file`
    """
    img = mmcv.imread(img)
    img = img.copy()
    if palette is None:
        palette = np.random.randint(0, 255, size=(len(class_names), 3))
    palette = np.array(palette)
    assert palette.shape[0] == len(class_names)
    assert palette.shape[1] == 3
    assert len(palette.shape) == 2
    assert 0 < opacity <= 1.0
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color
    # convert to BGR
    color_seg = color_seg[..., ::-1]

    img = img * (1 - opacity) + color_seg * opacity
    img = img.astype(np.uint8)
    # if out_file specified, do not show image in window
    if out_file is not None:
        show = False

    if show:
        mmcv.imshow(img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)

    if not (show or out_file):
        warnings.warn('show==False and out_file is not specified, only '
                      'result image will be returned')
    return img


def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
    if show_origin is True:
        # only keep pipeline of Loading data and ann
        _data_cfg['pipeline'] = [
            x for x in _data_cfg.pipeline if 'Load' in x['type']
        ]
    else:
        _data_cfg['pipeline'] = [
            x for x in _data_cfg.pipeline if x['type'] not in skip_type
        ]


def retrieve_data_cfg(config_path, skip_type, cfg_options, show_origin=False):
    cfg = Config.fromfile(config_path)
    if cfg_options is not None:
        cfg.merge_from_dict(cfg_options)
    train_data_cfg = cfg.data.train
    if isinstance(train_data_cfg, list):
        for _data_cfg in train_data_cfg:
            while 'dataset' in _data_cfg and _data_cfg[
                    'type'] != 'MultiImageMixDataset':
                _data_cfg = _data_cfg['dataset']
            if 'pipeline' in _data_cfg:
                _retrieve_data_cfg(_data_cfg, skip_type, show_origin)
            else:
                raise ValueError
    else:
        while 'dataset' in train_data_cfg and train_data_cfg[
                'type'] != 'MultiImageMixDataset':
            train_data_cfg = train_data_cfg['dataset']
        _retrieve_data_cfg(train_data_cfg, skip_type, show_origin)
    return cfg


def main():
    args = parse_args()
    cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options,
                            args.show_origin)
    dataset = DATASETS.build(cfg.data.train)
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # register all modules in mmseg into the registries
    register_all_modules()

    dataset = DATASETS.build(cfg.train_dataloader.dataset)

    visualizer = VISUALIZERS.build(cfg.visualizer)
    visualizer.dataset_meta = dataset.METAINFO

    progress_bar = mmcv.ProgressBar(len(dataset))
    for item in dataset:
        filename = os.path.join(args.output_dir,
                                Path(item['filename']).name
                                ) if args.output_dir is not None else None
        imshow_semantic(
            item['img'],
            item['gt_semantic_seg'],
            dataset.CLASSES,
            dataset.PALETTE,
            show=args.show,
            wait_time=args.show_interval,
            out_file=filename,
            opacity=args.opacity,
        )
        img = item['inputs'].permute(1, 2, 0).numpy()
        data_sample = item['data_sample'].numpy()
        img_path = osp.basename(item['data_sample'].img_path)

        img = img[..., [2, 1, 0]]  # bgr to rgb

        visualizer.add_datasample(
            osp.basename(img_path),
            img,
            data_sample,
            show=not args.not_show,
            wait_time=args.show_interval)

        progress_bar.update()
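After this change the browse script (its path is not shown in this view; presumably tools/browse_dataset.py) builds the dataset through the registry and draws samples with the visualizer built from `cfg.visualizer`, replacing the removed `imshow_semantic` helper. A command along these lines would exercise the new path; the config path is a placeholder and the chosen config must define a `visualizer` field:

# Hypothetical invocation; the script and config paths are placeholders.
python tools/browse_dataset.py configs/some_config.py --show-interval 2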
@@ -19,6 +19,15 @@ def parse_args():
        '--work-dir',
        help=('if specified, the evaluation metric results will be dumped'
              'into the directory as json'))
    parser.add_argument(
        '--show', action='store_true', help='show prediction results')
    parser.add_argument(
        '--show-dir',
        help='directory where painted images will be saved. '
        'If specified, it will be automatically saved '
        'to the work_dir/timestamp/show_dir')
    parser.add_argument(
        '--wait-time', type=float, default=2, help='the interval of show (s)')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
@@ -42,6 +51,26 @@ def parse_args():
    return args


def trigger_visualization_hook(cfg, args):
    default_hooks = cfg.default_hooks
    if 'visualization' in default_hooks:
        visualization_hook = default_hooks['visualization']
        # Turn on visualization
        visualization_hook['draw'] = True
        if args.show:
            visualization_hook['show'] = True
            visualization_hook['wait_time'] = args.wait_time
        if args.show_dir:
            visualization_hook['test_out_dir'] = args.show_dir
    else:
        raise RuntimeError(
            'VisualizationHook must be included in default_hooks. '
            'Refer to usage '
            '"visualization=dict(type=\'VisualizationHook\')"')

    return cfg


def main():
    args = parse_args()

@@ -66,6 +95,9 @@ def main():

    cfg.load_from = args.checkpoint

    if args.show or args.show_dir:
        cfg = trigger_visualization_hook(cfg, args)

    # build the runner from config
    runner = Runner.from_cfg(cfg)
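With `trigger_visualization_hook` in place, visualization during testing is driven from the command line of the test script (presumably tools/test.py). A sketch of the two typical invocations, with CONFIG and CHECKPOINT as placeholders rather than values from this commit:

# Hypothetical commands; CONFIG and CHECKPOINT are placeholders.
python tools/test.py CONFIG CHECKPOINT --show --wait-time 2
python tools/test.py CONFIG CHECKPOINT --show-dir vis_results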