# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
from unittest.mock import patch

import numpy as np
import torch

from mmpretrain.structures import DataSample
from mmpretrain.visualization import ClsVisualizer


class TestClsVisualizer(TestCase):
    """Tests for ``ClsVisualizer.add_datasample`` (show, saving, texts)."""

    def setUp(self) -> None:
        super().setUp()
        # Fresh scratch directory per test; removed again in ``tearDown``.
        tmpdir = tempfile.TemporaryDirectory()
        self.tmpdir = tmpdir
        self.vis = ClsVisualizer(
            save_dir=tmpdir.name,
            vis_backends=[dict(type='LocalVisBackend')],
        )

    def _expect_texts(self, *lines):
        """Return a ``draw_texts`` stand-in that asserts the drawn text is
        ``lines`` joined with newlines."""

        def draw_texts(text, *_, **__):
            self.assertEqual(text, '\n'.join(lines))

        return draw_texts

    def _expect_font_size(self, target_size):
        """Return a ``draw_texts`` stand-in that asserts the ``font_sizes``
        argument equals ``target_size``."""

        def draw_texts(text, font_sizes, *_, **__):
            self.assertEqual(font_sizes, target_size)

        return draw_texts

    def _add_datasample_checking_texts(self, checker, **kwargs):
        """Call ``add_datasample('test', **kwargs)`` with ``draw_texts``
        routed to ``checker``.

        ``draw_texts`` is replaced by a mock whose ``side_effect`` is
        ``checker``, so we can also assert it was actually invoked. With a
        plain function patch, the checker's assertions would be silently
        skipped (and the test would pass) if ``draw_texts`` were never called.
        """
        with patch.object(
                self.vis, 'draw_texts', side_effect=checker) as mock_draw:
            self.vis.add_datasample('test', **kwargs)
        mock_draw.assert_called_once()

    def test_add_datasample(self):
        image = np.ones((10, 10, 3), np.uint8)
        data_sample = DataSample().set_gt_label(1).set_pred_label(1).\
            set_pred_score(torch.tensor([0.1, 0.8, 0.1]))

        # Test show: the drawn image must differ from the input, and the
        # window name / wait time must be forwarded to ``show``.
        def mock_show(drawn_img, win_name, wait_time):
            self.assertFalse((image == drawn_img).all())
            self.assertEqual(win_name, 'test')
            self.assertEqual(wait_time, 0)

        with patch.object(
                self.vis, 'show', side_effect=mock_show) as show_mock:
            self.vis.add_datasample(
                'test', image=image, data_sample=data_sample, show=True)
        # Guard against the assertions in ``mock_show`` never running.
        show_mock.assert_called_once()

        # Test out_file
        out_file = osp.join(self.tmpdir.name, 'results.png')
        self.vis.add_datasample(
            'test', image=image, data_sample=data_sample, out_file=out_file)
        self.assertTrue(osp.exists(out_file))

        # Test storage backend: an image stored via the LocalVisBackend by an
        # earlier ``add_datasample`` call in this test.
        save_file = osp.join(self.tmpdir.name, 'vis_data/vis_image/test_0.png')
        self.assertTrue(osp.exists(save_file))

        # Test with dataset_meta: class names are shown next to indices.
        self.vis.dataset_meta = {'classes': ['cat', 'bird', 'dog']}
        self._add_datasample_checking_texts(
            self._expect_texts(
                'Ground truth: 1 (bird)',
                'Prediction: 1, 0.80 (bird)',
            ),
            image=image,
            data_sample=data_sample)

        # Test with draw_pred=False: only the ground-truth line is drawn.
        self._add_datasample_checking_texts(
            self._expect_texts('Ground truth: 1 (bird)'),
            image=image,
            data_sample=data_sample,
            draw_pred=False)

        # Test with draw_gt=False: only the prediction line is drawn.
        self._add_datasample_checking_texts(
            self._expect_texts('Prediction: 1, 0.80 (bird)'),
            image=image,
            data_sample=data_sample,
            draw_gt=False)

        # Test without score: the prediction line omits the score.
        del data_sample.pred_score
        self._add_datasample_checking_texts(
            self._expect_texts(
                'Ground truth: 1 (bird)',
                'Prediction: 1 (bird)',
            ),
            image=image,
            data_sample=data_sample)

        # Test adaptive font size: depends on the image's spatial size.
        for shape, font_size in [
            ((224, 384, 3), 7),
            ((10, 384, 3), 2),
            ((1000, 1000, 3), 21),
        ]:
            with self.subTest(shape=shape):
                self._add_datasample_checking_texts(
                    self._expect_font_size(font_size),
                    image=np.ones(shape, np.uint8),
                    data_sample=data_sample)

        # Test rescale image: with rescale_factor=2. the 224x384 image
        # expects font size 14 instead of 7.
        self._add_datasample_checking_texts(
            self._expect_font_size(14),
            image=np.ones((224, 384, 3), np.uint8),
            rescale_factor=2.,
            data_sample=data_sample)

    def tearDown(self):
        self.tmpdir.cleanup()
        super().tearDown()