201 lines
6.7 KiB
Python
201 lines
6.7 KiB
Python
# Copyright (c) Open-MMLab. All rights reserved.
|
|
import os.path as osp
|
|
import tempfile
|
|
from unittest import TestCase
|
|
from unittest.mock import patch
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from mmpretrain.structures import DataSample
|
|
from mmpretrain.visualization import UniversalVisualizer
|
|
|
|
|
|
class TestUniversalVisualizer(TestCase):
|
|
|
|
def setUp(self) -> None:
|
|
super().setUp()
|
|
tmpdir = tempfile.TemporaryDirectory()
|
|
self.tmpdir = tmpdir
|
|
self.vis = UniversalVisualizer(
|
|
save_dir=tmpdir.name,
|
|
vis_backends=[dict(type='LocalVisBackend')],
|
|
)
|
|
|
|
def test_visualize_cls(self):
|
|
image = np.ones((10, 10, 3), np.uint8)
|
|
data_sample = DataSample().set_gt_label(1).set_pred_label(1).\
|
|
set_pred_score(torch.tensor([0.1, 0.8, 0.1]))
|
|
|
|
# Test show
|
|
def mock_show(drawn_img, win_name, wait_time):
|
|
self.assertFalse((image == drawn_img).all())
|
|
self.assertEqual(win_name, 'test_cls')
|
|
self.assertEqual(wait_time, 0)
|
|
|
|
with patch.object(self.vis, 'show', mock_show):
|
|
self.vis.visualize_cls(
|
|
image=image,
|
|
data_sample=data_sample,
|
|
show=True,
|
|
name='test_cls',
|
|
step=1)
|
|
|
|
# Test storage backend.
|
|
save_file = osp.join(self.tmpdir.name,
|
|
'vis_data/vis_image/test_cls_1.png')
|
|
self.assertTrue(osp.exists(save_file))
|
|
|
|
# Test out_file
|
|
out_file = osp.join(self.tmpdir.name, 'results.png')
|
|
self.vis.visualize_cls(
|
|
image=image, data_sample=data_sample, out_file=out_file)
|
|
self.assertTrue(osp.exists(out_file))
|
|
|
|
# Test with dataset_meta
|
|
self.vis.dataset_meta = {'classes': ['cat', 'bird', 'dog']}
|
|
|
|
def patch_texts(text, *_, **__):
|
|
self.assertEqual(
|
|
text, '\n'.join([
|
|
'Ground truth: 1 (bird)',
|
|
'Prediction: 1, 0.80 (bird)',
|
|
]))
|
|
|
|
with patch.object(self.vis, 'draw_texts', patch_texts):
|
|
self.vis.visualize_cls(image, data_sample)
|
|
|
|
# Test without pred_label
|
|
def patch_texts(text, *_, **__):
|
|
self.assertEqual(text, '\n'.join([
|
|
'Ground truth: 1 (bird)',
|
|
]))
|
|
|
|
with patch.object(self.vis, 'draw_texts', patch_texts):
|
|
self.vis.visualize_cls(image, data_sample, draw_pred=False)
|
|
|
|
# Test without gt_label
|
|
def patch_texts(text, *_, **__):
|
|
self.assertEqual(text, '\n'.join([
|
|
'Prediction: 1, 0.80 (bird)',
|
|
]))
|
|
|
|
with patch.object(self.vis, 'draw_texts', patch_texts):
|
|
self.vis.visualize_cls(image, data_sample, draw_gt=False)
|
|
|
|
# Test without score
|
|
del data_sample.pred_score
|
|
|
|
def patch_texts(text, *_, **__):
|
|
self.assertEqual(
|
|
text, '\n'.join([
|
|
'Ground truth: 1 (bird)',
|
|
'Prediction: 1 (bird)',
|
|
]))
|
|
|
|
with patch.object(self.vis, 'draw_texts', patch_texts):
|
|
self.vis.visualize_cls(image, data_sample)
|
|
|
|
# Test adaptive font size
|
|
def assert_font_size(target_size):
|
|
|
|
def draw_texts(text, font_sizes, *_, **__):
|
|
self.assertEqual(font_sizes, target_size)
|
|
|
|
return draw_texts
|
|
|
|
with patch.object(self.vis, 'draw_texts', assert_font_size(7)):
|
|
self.vis.visualize_cls(
|
|
np.ones((224, 384, 3), np.uint8), data_sample)
|
|
|
|
with patch.object(self.vis, 'draw_texts', assert_font_size(2)):
|
|
self.vis.visualize_cls(
|
|
np.ones((10, 384, 3), np.uint8), data_sample)
|
|
|
|
with patch.object(self.vis, 'draw_texts', assert_font_size(21)):
|
|
self.vis.visualize_cls(
|
|
np.ones((1000, 1000, 3), np.uint8), data_sample)
|
|
|
|
# Test rescale image
|
|
with patch.object(self.vis, 'draw_texts', assert_font_size(14)):
|
|
self.vis.visualize_cls(
|
|
np.ones((224, 384, 3), np.uint8),
|
|
data_sample,
|
|
rescale_factor=2.)
|
|
|
|
def test_visualize_image_retrieval(self):
|
|
image = np.ones((10, 10, 3), np.uint8)
|
|
data_sample = DataSample().set_pred_score([0.1, 0.8, 0.1])
|
|
|
|
class ToyPrototype:
|
|
|
|
def get_data_info(self, idx):
|
|
img_path = osp.join(osp.dirname(__file__), '../data/color.jpg')
|
|
return {'img_path': img_path, 'sample_idx': idx}
|
|
|
|
prototype_dataset = ToyPrototype()
|
|
|
|
# Test show
|
|
def mock_show(drawn_img, win_name, wait_time):
|
|
if image.shape == drawn_img.shape:
|
|
self.assertFalse((image == drawn_img).all())
|
|
self.assertEqual(win_name, 'test_retrieval')
|
|
self.assertEqual(wait_time, 0)
|
|
|
|
with patch.object(self.vis, 'show', mock_show):
|
|
self.vis.visualize_image_retrieval(
|
|
image,
|
|
data_sample,
|
|
prototype_dataset,
|
|
show=True,
|
|
name='test_retrieval',
|
|
step=1)
|
|
|
|
# Test storage backend.
|
|
save_file = osp.join(self.tmpdir.name,
|
|
'vis_data/vis_image/test_retrieval_1.png')
|
|
self.assertTrue(osp.exists(save_file))
|
|
|
|
# Test out_file
|
|
out_file = osp.join(self.tmpdir.name, 'results.png')
|
|
self.vis.visualize_image_retrieval(
|
|
image,
|
|
data_sample,
|
|
prototype_dataset,
|
|
out_file=out_file,
|
|
)
|
|
self.assertTrue(osp.exists(out_file))
|
|
|
|
def test_visualize_masked_image(self):
|
|
image = np.ones((10, 10, 3), np.uint8)
|
|
data_sample = DataSample().set_mask(
|
|
torch.tensor([
|
|
[0, 0, 1, 1],
|
|
[0, 1, 1, 0],
|
|
[1, 1, 0, 0],
|
|
[1, 0, 0, 1],
|
|
]))
|
|
|
|
# Test show
|
|
def mock_show(drawn_img, win_name, wait_time):
|
|
self.assertTupleEqual(drawn_img.shape, (224, 224, 3))
|
|
self.assertEqual(win_name, 'test_mask')
|
|
self.assertEqual(wait_time, 0)
|
|
|
|
with patch.object(self.vis, 'show', mock_show):
|
|
self.vis.visualize_masked_image(
|
|
image, data_sample, show=True, name='test_mask', step=1)
|
|
|
|
# Test storage backend.
|
|
save_file = osp.join(self.tmpdir.name,
|
|
'vis_data/vis_image/test_mask_1.png')
|
|
self.assertTrue(osp.exists(save_file))
|
|
|
|
# Test out_file
|
|
out_file = osp.join(self.tmpdir.name, 'results.png')
|
|
self.vis.visualize_masked_image(image, data_sample, out_file=out_file)
|
|
self.assertTrue(osp.exists(out_file))
|
|
|
|
def tearDown(self):
|
|
self.tmpdir.cleanup()
|