189 lines
7.8 KiB
Python
189 lines
7.8 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import Optional, Tuple
|
|
|
|
import mmcv
|
|
import numpy as np
|
|
from mmengine import Visualizer
|
|
from mmengine.dist import master_only
|
|
|
|
from mmcls.registry import VISUALIZERS
|
|
from mmcls.structures import ClsDataSample
|
|
|
|
|
|
def _get_adaptive_scale(img_shape: Tuple[int, int],
                        min_scale: float = 0.3,
                        max_scale: float = 3.0) -> float:
    """Compute a drawing scale proportional to the image's short edge.

    The scale grows linearly with the shortest side of the image and is
    normalized so that a short edge of 224 pixels yields exactly 1.0. The
    result is then clamped into the ``[min_scale, max_scale]`` range.

    Args:
        img_shape (Tuple[int, int]): The shape of the canvas image. The
            smallest entry is taken as the short edge.
        min_scale (float): Lower bound of the returned scale.
            Defaults to 0.3.
        max_scale (float): Upper bound of the returned scale.
            Defaults to 3.0.

    Returns:
        float: The adaptive scale.
    """
    # A short edge of 224 pixels is the reference that maps to scale 1.0.
    reference_edge = 224.
    linear_scale = min(img_shape) / reference_edge
    # Apply the lower bound first, then the upper bound.
    lower_bounded = max(linear_scale, min_scale)
    return min(lower_bounded, max_scale)
|
|
|
|
|
|
@VISUALIZERS.register_module()
class ClsVisualizer(Visualizer):
    """Universal Visualizer for classification task.

    Args:
        name (str): Name of the instance. Defaults to 'visualizer'.
        image (np.ndarray, optional): the origin image to draw. The format
            should be RGB. Defaults to None.
        vis_backends (list, optional): Visual backend config list.
            Defaults to None.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        fig_save_cfg (dict): Keyword parameters of figure for saving.
            Defaults to empty dict.
        fig_show_cfg (dict): Keyword parameters of figure for showing.
            Defaults to empty dict.

    Examples:
        >>> import torch
        >>> import mmcv
        >>> from pathlib import Path
        >>> from mmcls.visualization import ClsVisualizer
        >>> from mmcls.structures import ClsDataSample
        >>> # Example image
        >>> img = mmcv.imread("./demo/bird.JPEG", channel_order='rgb')
        >>> # Example annotation
        >>> data_sample = ClsDataSample().set_gt_label(1).set_pred_label(
        ...     1).set_pred_score(torch.tensor([0.1, 0.8, 0.1]))
        >>> # Setup the visualizer
        >>> vis = ClsVisualizer(
        ...     save_dir="./outputs",
        ...     vis_backends=[dict(type='LocalVisBackend')])
        >>> # Set classes names
        >>> vis.dataset_meta = {'classes': ['cat', 'bird', 'dog']}
        >>> # Show the example image with annotation in a figure.
        >>> # And it will ignore all preset storage backends.
        >>> vis.add_datasample('res', img, data_sample, show=True)
        >>> # Save the visualization result by the specified storage backends.
        >>> vis.add_datasample('res', img, data_sample)
        >>> assert Path('./outputs/vis_data/vis_image/res_0.png').exists()
        >>> # Save another visualization result with the same name.
        >>> vis.add_datasample('res', img, data_sample, step=1)
        >>> assert Path('./outputs/vis_data/vis_image/res_1.png').exists()
    """

    @master_only
    def add_datasample(self,
                       name: str,
                       image: np.ndarray,
                       data_sample: Optional[ClsDataSample] = None,
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       draw_score: bool = True,
                       rescale_factor: Optional[float] = None,
                       show: bool = False,
                       text_cfg: dict = dict(),
                       wait_time: float = 0,
                       out_file: Optional[str] = None,
                       step: int = 0) -> None:
        """Draw datasample and save to all backends.

        - If ``out_file`` is specified, all storage backends are ignored
          and the image is saved to ``out_file`` instead.
        - If ``show`` is True, plot the result image in a window, please
          confirm you are able to access the graphical interface.

        Args:
            name (str): The image identifier.
            image (np.ndarray): The image to draw.
            data_sample (:obj:`ClsDataSample`, optional): The annotation of the
                image. If None, no label text is collected and only the
                image itself is processed. Defaults to None.
            draw_gt (bool): Whether to draw ground truth labels.
                Defaults to True.
            draw_pred (bool): Whether to draw prediction labels.
                Defaults to True.
            draw_score (bool): Whether to draw the prediction scores
                of prediction categories. Defaults to True.
            rescale_factor (float, optional): Rescale the image by the rescale
                factor before visualization. Defaults to None.
            show (bool): Whether to display the drawn image. Defaults to False.
            text_cfg (dict): Extra text setting, which accepts
                arguments of :attr:`mmengine.Visualizer.draw_texts`.
                Entries given here override the built-in defaults.
                Defaults to an empty dict.
            wait_time (float): The interval of show (s). Defaults to 0, which
                means "forever".
            out_file (str, optional): Extra path to save the visualization
                result. If specified, the visualizer will only save the result
                image to the out_file and ignore its storage backends.
                Defaults to None.
            step (int): Global step value to record. Defaults to 0.
        """
        # Class names are optional; without them only indices are drawn.
        classes = None
        if self.dataset_meta is not None:
            classes = self.dataset_meta.get('classes', None)

        if rescale_factor is not None:
            image = mmcv.imrescale(image, rescale_factor)

        texts = []
        self.set_image(image)

        # ``data_sample`` defaults to None, and a membership test against
        # None raises TypeError, so check for it explicitly first.
        has_sample = data_sample is not None

        if draw_gt and has_sample and 'gt_label' in data_sample:
            idx = data_sample.gt_label.label.tolist()
            class_labels = [''] * len(idx)
            if classes is not None:
                class_labels = [f' ({classes[i]})' for i in idx]
            labels = [
                str(label) + class_text
                for label, class_text in zip(idx, class_labels)
            ]
            prefix = 'Ground truth: '
            # Indent continuation lines so labels align under the first.
            texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))

        if draw_pred and has_sample and 'pred_label' in data_sample:
            pred_label = data_sample.pred_label
            idx = pred_label.label.tolist()
            score_labels = [''] * len(idx)
            class_labels = [''] * len(idx)
            if draw_score and 'score' in pred_label:
                # ``score`` is indexed by the predicted label value itself.
                score_labels = [
                    f', {pred_label.score[i].item():.2f}' for i in idx
                ]

            if classes is not None:
                class_labels = [f' ({classes[i]})' for i in idx]

            labels = [
                str(label) + score_text + class_text
                for label, score_text, class_text in zip(
                    idx, score_labels, class_labels)
            ]
            prefix = 'Prediction: '
            texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))

        # Text position and font size follow the image's short edge.
        img_scale = _get_adaptive_scale(image.shape[:2])
        # A new dict is built here, so the shared default ``text_cfg``
        # argument is never mutated. User entries are unpacked last and
        # therefore take precedence over the defaults.
        text_cfg = {
            'positions': np.array([(img_scale * 5, ) * 2]).astype(np.int32),
            'font_sizes': int(img_scale * 7),
            'font_families': 'monospace',
            'colors': 'white',
            'bboxes': dict(facecolor='black', alpha=0.5, boxstyle='Round'),
            **text_cfg
        }
        self.draw_texts('\n'.join(texts), **text_cfg)
        drawn_img = self.get_image()

        if show:
            self.show(drawn_img, win_name=name, wait_time=wait_time)

        if out_file is not None:
            # save the image to the target file instead of vis_backends;
            # the last axis is reversed before writing (presumably RGB ->
            # BGR, the channel order mmcv.imwrite expects).
            mmcv.imwrite(drawn_img[..., ::-1], out_file)
        else:
            self.add_image(name, drawn_img, step=step)
|