182 lines
7.3 KiB
Python
182 lines
7.3 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import mmcv
|
|
import numpy as np
|
|
from mmengine.dist import master_only
|
|
from mmengine.structures import PixelData
|
|
from mmengine.visualization import Visualizer
|
|
|
|
from mmseg.registry import VISUALIZERS
|
|
from mmseg.structures import SegDataSample
|
|
|
|
|
|
@VISUALIZERS.register_module()
class SegLocalVisualizer(Visualizer):
    """Local Visualizer.

    Args:
        name (str): Name of the instance. Defaults to 'visualizer'.
        image (np.ndarray, optional): the original image to draw. The format
            should be RGB. Defaults to None.
        vis_backends (list, optional): Visual backend config list.
            Defaults to None.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        alpha (int, float): The transparency of segmentation mask.
            Defaults to 0.8.

    Examples:
        >>> import numpy as np
        >>> import torch
        >>> from mmengine.structures import PixelData
        >>> from mmseg.structures import SegDataSample
        >>> from mmseg.engine.visualization import SegLocalVisualizer

        >>> seg_local_visualizer = SegLocalVisualizer()
        >>> image = np.random.randint(0, 256,
        ...                     size=(10, 12, 3)).astype('uint8')
        >>> gt_sem_seg_data = dict(data=torch.randint(0, 2, (1, 10, 12)))
        >>> gt_sem_seg = PixelData(**gt_sem_seg_data)
        >>> gt_seg_data_sample = SegDataSample()
        >>> gt_seg_data_sample.gt_sem_seg = gt_sem_seg
        >>> seg_local_visualizer.dataset_meta = dict(
        ...     classes=('background', 'foreground'),
        ...     palette=[[120, 120, 120], [6, 230, 230]])
        >>> seg_local_visualizer.add_datasample('visualizer_example',
        ...                         image, gt_seg_data_sample)
        >>> seg_local_visualizer.add_datasample(
        ...                        'visualizer_example', image,
        ...                         gt_seg_data_sample, show=True)
    """

    def __init__(self,
                 name: str = 'visualizer',
                 image: Optional[np.ndarray] = None,
                 vis_backends: Optional[Dict] = None,
                 save_dir: Optional[str] = None,
                 alpha: float = 0.8,
                 **kwargs):
        super().__init__(name, image, vis_backends, save_dir, **kwargs)
        self.alpha = alpha
        # Set default value. When calling
        # `SegLocalVisualizer().dataset_meta=xxx`,
        # it will override the default value.
        self.dataset_meta = {}

    def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
                      classes: Optional[Tuple[str]],
                      palette: Optional[List[List[int]]]) -> np.ndarray:
        """Draw semantic seg of GT or prediction.

        Args:
            image (np.ndarray): The image to draw.
            sem_seg (:obj:`PixelData`): Data structure for
                pixel-level annotations or predictions.
            classes (Tuple[str], optional): Category information. Must not
                be None here; its length bounds which label ids are drawn.
            palette (List[List[int]], optional): The palette of
                segmentation map. If None, a fixed pseudo-random color is
                generated per class, so a given class always gets the same
                color across calls.

        Returns:
            np.ndarray: the drawn image which channel is RGB.
        """
        num_classes = len(classes)
        if palette is None:
            # Seeded local generator: deterministic colors across calls
            # without disturbing NumPy's global random state.
            rng = np.random.default_rng(42)
            palette = rng.integers(0, 256, size=(num_classes, 3)).tolist()

        sem_seg = sem_seg.cpu().data
        # Label ids in descending order; ids >= num_classes have no class
        # entry (presumably an ignore value) and are not drawn.
        ids = np.unique(sem_seg)[::-1]
        labels = np.array(ids[ids < num_classes], dtype=np.int64)
        colors = [palette[label] for label in labels]

        self.set_image(image)

        # draw semantic masks
        for label, color in zip(labels, colors):
            self.draw_binary_masks(
                sem_seg == label, colors=[color], alphas=self.alpha)

        return self.get_image()

    @master_only
    def add_datasample(
            self,
            name: str,
            image: np.ndarray,
            data_sample: Optional[SegDataSample] = None,
            draw_gt: bool = True,
            draw_pred: bool = True,
            show: bool = False,
            wait_time: float = 0,
            # TODO: Supported in mmengine's Viusalizer.
            out_file: Optional[str] = None,
            step: int = 0) -> None:
        """Draw datasample and save to all backends.

        - If GT and prediction are plotted at the same time, they are
          displayed in a stitched image where the left image is the
          ground truth and the right image is the prediction.
        - If neither GT nor prediction is drawn, the input image is used
          unchanged.
        - If ``show`` is True, the drawn image is displayed in a local
          window; it is still saved as described below.
        - If ``out_file`` is specified, the drawn image is written to
          ``out_file`` (converted from RGB to BGR for ``mmcv.imwrite``)
          instead of being sent to the storage backends. It is usually
          used when the display is not available.

        Args:
            name (str): The image identifier.
            image (np.ndarray): The image to draw. The format should be RGB.
            data_sample (:obj:`SegDataSample`, optional): Data sample that
                may carry ``gt_sem_seg`` and/or ``pred_sem_seg``.
                Defaults to None.
            draw_gt (bool): Whether to draw GT SegDataSample. Default to True.
            draw_pred (bool): Whether to draw Prediction SegDataSample.
                Defaults to True.
            show (bool): Whether to display the drawn image. Default to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            out_file (str): Path to output file. Defaults to None.
            step (int): Global step value to record. Defaults to 0.
        """
        classes = self.dataset_meta.get('classes', None)
        palette = self.dataset_meta.get('palette', None)

        gt_img_data = None
        pred_img_data = None

        if draw_gt and data_sample is not None and 'gt_sem_seg' in data_sample:
            assert classes is not None, 'class information is ' \
                                        'not provided when ' \
                                        'visualizing semantic ' \
                                        'segmentation results.'
            gt_img_data = self._draw_sem_seg(image, data_sample.gt_sem_seg,
                                             classes, palette)

        if (draw_pred and data_sample is not None
                and 'pred_sem_seg' in data_sample):
            assert classes is not None, 'class information is ' \
                                        'not provided when ' \
                                        'visualizing semantic ' \
                                        'segmentation results.'
            pred_img_data = self._draw_sem_seg(image,
                                               data_sample.pred_sem_seg,
                                               classes, palette)

        if gt_img_data is not None and pred_img_data is not None:
            drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
        elif gt_img_data is not None:
            drawn_img = gt_img_data
        elif pred_img_data is not None:
            drawn_img = pred_img_data
        else:
            # Nothing was drawn: fall back to the input image rather than
            # passing None to show() / imwrite() / add_image().
            drawn_img = image

        if show:
            self.show(drawn_img, win_name=name, wait_time=wait_time)

        if out_file is not None:
            # drawn_img is RGB, while mmcv.imwrite (OpenCV) expects BGR.
            mmcv.imwrite(mmcv.rgb2bgr(drawn_img), out_file)
        else:
            self.add_image(name, drawn_img, step)