pull/1702/merge
haofeng 2024-01-06 16:22:57 +08:00 committed by GitHub
commit f5ebe3383b
2 changed files with 105 additions and 28 deletions


@@ -11,7 +11,7 @@ from mmengine.visualization import Visualizer
 from mmengine.visualization.utils import img_from_canvas
 from mmpretrain.registry import VISUALIZERS
-from mmpretrain.structures import DataSample
+from mmpretrain.structures import DataSample, MultiTaskDataSample
 from .utils import create_figure, get_adaptive_scale
@@ -99,6 +99,67 @@ class UniversalVisualizer(Visualizer):
         Returns:
             np.ndarray: The visualization image.
         """
+
+        def _draw_gt(data_sample: DataSample,
+                     classes: Optional[Sequence[str]],
+                     draw_gt: bool,
+                     texts: Sequence[str],
+                     parent_task: str = ''):
+            if isinstance(data_sample, MultiTaskDataSample):
+                for task in data_sample.tasks:
+                    sub_task = f'{parent_task}_{task}' if parent_task else task
+                    _draw_gt(
+                        data_sample.get(task), classes, draw_gt, texts,
+                        sub_task)
+            else:
+                if draw_gt and 'gt_label' in data_sample:
+                    idx = data_sample.gt_label.tolist()
+                    class_labels = [''] * len(idx)
+                    if classes is not None:
+                        class_labels = [f' ({classes[i]})' for i in idx]
+                    labels = [
+                        str(idx[i]) + class_labels[i] for i in range(len(idx))
+                    ]
+                    prefix = f'{parent_task} Ground truth: ' if parent_task \
+                        else 'Ground truth: '
+                    texts.append(prefix +
+                                 ('\n' + ' ' * len(prefix)).join(labels))
+
+        def _draw_pred(data_sample: DataSample,
+                       classes: Optional[Sequence[str]],
+                       draw_pred: bool,
+                       draw_score: bool,
+                       texts: Sequence[str],
+                       parent_task: str = ''):
+            if isinstance(data_sample, MultiTaskDataSample):
+                for task in data_sample.tasks:
+                    sub_task = f'{parent_task}_{task}' if parent_task else task
+                    _draw_pred(
+                        data_sample.get(task), classes, draw_pred, draw_score,
+                        texts, sub_task)
+            else:
+                if draw_pred and 'pred_label' in data_sample:
+                    idx = data_sample.pred_label.tolist()
+                    score_labels = [''] * len(idx)
+                    class_labels = [''] * len(idx)
+                    if draw_score and 'pred_score' in data_sample:
+                        score_labels = [
+                            f', {data_sample.pred_score[i].item():.2f}'
+                            for i in idx
+                        ]
+                    if classes is not None:
+                        class_labels = [f' ({classes[i]})' for i in idx]
+                    labels = [
+                        str(idx[i]) + score_labels[i] + class_labels[i]
+                        for i in range(len(idx))
+                    ]
+                    prefix = f'{parent_task} Prediction: ' if parent_task \
+                        else 'Prediction: '
+                    texts.append(prefix +
+                                 ('\n' + ' ' * len(prefix)).join(labels))
+
         if self.dataset_meta is not None:
             classes = classes or self.dataset_meta.get('classes', None)
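
Side note for reviewers: both helpers share the same recursive shape, so the sketch below (standalone, with plain dicts standing in for MultiTaskDataSample and bare ints standing in for leaf labels; every name here is illustrative, not part of the patch) shows how nested task names get flattened into underscore-joined prefixes and how each leaf contributes one text line:

from typing import List

def collect(sample, texts: List[str], parent_task: str = '') -> None:
    # dicts stand in for MultiTaskDataSample; anything else is a leaf label
    if isinstance(sample, dict):
        for task, sub in sample.items():
            # compose sub-task names exactly the way the patch does
            sub_task = f'{parent_task}_{task}' if parent_task else task
            collect(sub, texts, sub_task)
    else:
        prefix = f'{parent_task} Ground truth: ' if parent_task \
            else 'Ground truth: '
        texts.append(prefix + str(sample))

texts: List[str] = []
collect({'task0': {'task00': 2, 'task01': 1}, 'task1': 1}, texts)
print(texts)
# ['task0_task00 Ground truth: 2',
#  'task0_task01 Ground truth: 1',
#  'task1 Ground truth: 1']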
@@ -114,33 +175,9 @@ class UniversalVisualizer(Visualizer):
         texts = []
         self.set_image(image)
 
-        if draw_gt and 'gt_label' in data_sample:
-            idx = data_sample.gt_label.tolist()
-            class_labels = [''] * len(idx)
-            if classes is not None:
-                class_labels = [f' ({classes[i]})' for i in idx]
-            labels = [str(idx[i]) + class_labels[i] for i in range(len(idx))]
-            prefix = 'Ground truth: '
-            texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
+        _draw_gt(data_sample, classes, draw_gt, texts)
 
-        if draw_pred and 'pred_label' in data_sample:
-            idx = data_sample.pred_label.tolist()
-            score_labels = [''] * len(idx)
-            class_labels = [''] * len(idx)
-            if draw_score and 'pred_score' in data_sample:
-                score_labels = [
-                    f', {data_sample.pred_score[i].item():.2f}' for i in idx
-                ]
-            if classes is not None:
-                class_labels = [f' ({classes[i]})' for i in idx]
-            labels = [
-                str(idx[i]) + score_labels[i] + class_labels[i]
-                for i in range(len(idx))
-            ]
-            prefix = 'Prediction: '
-            texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
+        _draw_pred(data_sample, classes, draw_pred, draw_score, texts)
 
         img_scale = get_adaptive_scale(image.shape[:2])
         text_cfg = {
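
One detail worth calling out in the formatting that the old and new code share: continuation lines are padded to the prefix width, so multi-label entries stay aligned under the first one. A tiny self-contained illustration (the labels are made up):

prefix = 'Ground truth: '
labels = ['1 (cat)', '2 (dog)']
print(prefix + ('\n' + ' ' * len(prefix)).join(labels))
# Ground truth: 1 (cat)
#               2 (dog)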


@@ -7,7 +7,7 @@ from unittest.mock import patch
 import numpy as np
 import torch
 
-from mmpretrain.structures import DataSample
+from mmpretrain.structures import DataSample, MultiTaskDataSample
 from mmpretrain.visualization import UniversalVisualizer
@@ -123,6 +123,46 @@ class TestUniversalVisualizer(TestCase):
                 data_sample,
                 rescale_factor=2.)
 
+    def test_visualize_multitask_cls(self):
+        image = np.ones((1000, 1000, 3), np.uint8)
+        gt_label = {'task0': {'task00': 2, 'task01': 1}, 'task1': 1}
+        data_sample = MultiTaskDataSample()
+        task_sample = DataSample().set_gt_label(
+            gt_label['task1']).set_pred_label(1).set_pred_score(
+                torch.tensor([0.1, 0.8, 0.1]))
+        data_sample.set_field(task_sample, 'task1')
+        data_sample.set_field(MultiTaskDataSample(), 'task0')
+        for task_name in gt_label['task0']:
+            task_sample = DataSample().set_gt_label(
+                gt_label['task0'][task_name]).set_pred_label(2).set_pred_score(
+                    torch.tensor([0.1, 0.4, 0.5]))
+            data_sample.task0.set_field(task_sample, task_name)
+
+        # Test show
+        def mock_show(drawn_img, win_name, wait_time):
+            self.assertFalse((image == drawn_img).all())
+            self.assertEqual(win_name, 'test_cls')
+            self.assertEqual(wait_time, 0)
+
+        with patch.object(self.vis, 'show', mock_show):
+            self.vis.visualize_cls(
+                image=image,
+                data_sample=data_sample,
+                show=True,
+                name='test_cls',
+                step=2)
+
+        # Test storage backend.
+        save_file = osp.join(self.tmpdir.name,
+                             'vis_data/vis_image/test_cls_2.png')
+        self.assertTrue(osp.exists(save_file))
+
+        # Test out_file
+        out_file = osp.join(self.tmpdir.name, 'results_2.png')
+        self.vis.visualize_cls(
+            image=image, data_sample=data_sample, out_file=out_file)
+        self.assertTrue(osp.exists(out_file))
+
     def test_visualize_image_retrieval(self):
         image = np.ones((10, 10, 3), np.uint8)
         data_sample = DataSample().set_pred_score([0.1, 0.8, 0.1])
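
For completeness, a minimal usage sketch of the new path outside the test harness. It assumes a default-constructed UniversalVisualizer and relies on visualize_cls returning the drawn image, as the docstring above states; the classes list and task name are illustrative:

import numpy as np
import torch

from mmpretrain.structures import DataSample, MultiTaskDataSample
from mmpretrain.visualization import UniversalVisualizer

vis = UniversalVisualizer()
image = np.full((224, 224, 3), 255, dtype=np.uint8)

# A single sub-task; nest another MultiTaskDataSample via set_field to
# mirror the deeper structure exercised by the test above.
sample = MultiTaskDataSample()
sample.set_field(
    DataSample().set_gt_label(1).set_pred_label(1).set_pred_score(
        torch.tensor([0.1, 0.8, 0.1])), 'task1')

drawn = vis.visualize_cls(
    image=image,
    data_sample=sample,
    classes=['cat', 'dog', 'bird'])  # illustrative class names
print(drawn.shape)  # same spatial shape as the input image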