[feature] Visualizer compatible with MultiTaskDataSample

pull/1702/head
John Son 2023-07-10 08:46:47 +08:00
parent 7d850dfadd
commit 936c2a4966
2 changed files with 105 additions and 28 deletions

View File

@ -11,7 +11,7 @@ from mmengine.visualization import Visualizer
from mmengine.visualization.utils import img_from_canvas
from mmpretrain.registry import VISUALIZERS
from mmpretrain.structures import DataSample
from mmpretrain.structures import DataSample, MultiTaskDataSample
from .utils import create_figure, get_adaptive_scale
@ -99,6 +99,67 @@ class UniversalVisualizer(Visualizer):
Returns:
np.ndarray: The visualization image.
"""
def _draw_gt(data_sample: DataSample,
classes: Optional[Sequence[str]],
draw_gt: bool,
texts: Sequence[str],
parent_task: str = ''):
if isinstance(data_sample, MultiTaskDataSample):
for task in data_sample.tasks:
sub_task = f'{parent_task}_{task}' if parent_task else task
_draw_gt(
data_sample.get(task), classes, draw_gt, texts,
sub_task)
else:
if draw_gt and 'gt_label' in data_sample:
idx = data_sample.gt_label.tolist()
class_labels = [''] * len(idx)
if classes is not None:
class_labels = [f' ({classes[i]})' for i in idx]
labels = [
str(idx[i]) + class_labels[i] for i in range(len(idx))
]
prefix = f'{parent_task} Ground truth: ' if parent_task \
else 'Ground truth: '
texts.append(prefix +
('\n' + ' ' * len(prefix)).join(labels))
def _draw_pred(data_sample: DataSample,
classes: Optional[Sequence[str]],
draw_pred: bool,
draw_score: bool,
texts: Sequence[str],
parent_task: str = ''):
if isinstance(data_sample, MultiTaskDataSample):
for task in data_sample.tasks:
sub_task = f'{parent_task}_{task}' if parent_task else task
_draw_pred(
data_sample.get(task), classes, draw_pred, draw_score,
texts, sub_task)
else:
if draw_pred and 'pred_label' in data_sample:
idx = data_sample.pred_label.tolist()
score_labels = [''] * len(idx)
class_labels = [''] * len(idx)
if draw_score and 'pred_score' in data_sample:
score_labels = [
f', {data_sample.pred_score[i].item():.2f}'
for i in idx
]
if classes is not None:
class_labels = [f' ({classes[i]})' for i in idx]
labels = [
str(idx[i]) + score_labels[i] + class_labels[i]
for i in range(len(idx))
]
prefix = f'{parent_task} Prediction: ' if parent_task \
else 'Prediction: '
texts.append(prefix +
('\n' + ' ' * len(prefix)).join(labels))
if self.dataset_meta is not None:
classes = classes or self.dataset_meta.get('classes', None)
@ -114,33 +175,9 @@ class UniversalVisualizer(Visualizer):
texts = []
self.set_image(image)
if draw_gt and 'gt_label' in data_sample:
idx = data_sample.gt_label.tolist()
class_labels = [''] * len(idx)
if classes is not None:
class_labels = [f' ({classes[i]})' for i in idx]
labels = [str(idx[i]) + class_labels[i] for i in range(len(idx))]
prefix = 'Ground truth: '
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
_draw_gt(data_sample, classes, draw_gt, texts)
if draw_pred and 'pred_label' in data_sample:
idx = data_sample.pred_label.tolist()
score_labels = [''] * len(idx)
class_labels = [''] * len(idx)
if draw_score and 'pred_score' in data_sample:
score_labels = [
f', {data_sample.pred_score[i].item():.2f}' for i in idx
]
if classes is not None:
class_labels = [f' ({classes[i]})' for i in idx]
labels = [
str(idx[i]) + score_labels[i] + class_labels[i]
for i in range(len(idx))
]
prefix = 'Prediction: '
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
_draw_pred(data_sample, classes, draw_pred, draw_score, texts)
img_scale = get_adaptive_scale(image.shape[:2])
text_cfg = {

View File

@ -7,7 +7,7 @@ from unittest.mock import patch
import numpy as np
import torch
from mmpretrain.structures import DataSample
from mmpretrain.structures import DataSample, MultiTaskDataSample
from mmpretrain.visualization import UniversalVisualizer
@ -123,6 +123,46 @@ class TestUniversalVisualizer(TestCase):
data_sample,
rescale_factor=2.)
def test_visualize_multitask_cls(self):
image = np.ones((1000, 1000, 3), np.uint8)
gt_label = {'task0': {'task00': 2, 'task01': 1}, 'task1': 1}
data_sample = MultiTaskDataSample()
task_sample = DataSample().set_gt_label(
gt_label['task1']).set_pred_label(1).set_pred_score(
torch.tensor([0.1, 0.8, 0.1]))
data_sample.set_field(task_sample, 'task1')
data_sample.set_field(MultiTaskDataSample(), 'task0')
for task_name in gt_label['task0']:
task_sample = DataSample().set_gt_label(
gt_label['task0'][task_name]).set_pred_label(2).set_pred_score(
torch.tensor([0.1, 0.4, 0.5]))
data_sample.task0.set_field(task_sample, task_name)
# Test show
def mock_show(drawn_img, win_name, wait_time):
self.assertFalse((image == drawn_img).all())
self.assertEqual(win_name, 'test_cls')
self.assertEqual(wait_time, 0)
with patch.object(self.vis, 'show', mock_show):
self.vis.visualize_cls(
image=image,
data_sample=data_sample,
show=True,
name='test_cls',
step=2)
# Test storage backend.
save_file = osp.join(self.tmpdir.name,
'vis_data/vis_image/test_cls_2.png')
self.assertTrue(osp.exists(save_file))
# Test out_file
out_file = osp.join(self.tmpdir.name, 'results_2.png')
self.vis.visualize_cls(
image=image, data_sample=data_sample, out_file=out_file)
self.assertTrue(osp.exists(out_file))
def test_visualize_image_retrieval(self):
image = np.ones((10, 10, 3), np.uint8)
data_sample = DataSample().set_pred_score([0.1, 0.8, 0.1])