# Copyright (c) Open-MMLab. All rights reserved.
import os
import os.path as osp
import shutil
import tempfile
from unittest.mock import Mock, patch

import matplotlib.pyplot as plt
import mmcv
import numpy as np
import pytest

from mmcls.core import visualization as vis


def test_color():
    """Check ``color_val_matplotlib`` conversions and rejected inputs."""
    # The accepted forms below map to float RGB tuples; the tuple case shows
    # the channel order is reversed (BGR in, RGB out).
    assert vis.color_val_matplotlib(mmcv.Color.blue) == (0., 0., 1.)
    assert vis.color_val_matplotlib('green') == (0., 1., 0.)
    assert vis.color_val_matplotlib((1, 2, 3)) == (3 / 255, 2 / 255, 1 / 255)
    assert vis.color_val_matplotlib(100) == (100 / 255, 100 / 255, 100 / 255)
    assert vis.color_val_matplotlib(np.zeros(3, dtype=int)) == (0., 0., 0.)

    # a list of ints (here white) raises TypeError
    with pytest.raises(TypeError):
        vis.color_val_matplotlib([255, 255, 255])

    # forbid float
    with pytest.raises(TypeError):
        vis.color_val_matplotlib(1.0)

    # overflowed channel value
    with pytest.raises(AssertionError):
        vis.color_val_matplotlib((0, 0, 500))


def test_imshow_infos():
    """Check ``imshow_infos`` for file output, grayscale input and show mode."""
    # A fresh directory per run avoids clashing with concurrent runs or with
    # stale files from an earlier aborted run (a leftover ``1_0`` file would
    # make the show-mode assertion pass spuriously). The ``finally`` removes
    # it even when an assertion fails.
    tmp_dir = tempfile.mkdtemp()
    tmp_filename = osp.join(tmp_dir, 'image.jpg')
    try:
        # test color images written to disk
        image = np.ones((10, 10, 3), np.uint8)
        result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
        out_image = vis.imshow_infos(
            image, result, out_file=tmp_filename, show=False)
        assert osp.isfile(tmp_filename)
        assert image.shape == out_image.shape
        assert not np.allclose(image, out_image)
        os.remove(tmp_filename)

        # test grayscale images
        image = np.ones((10, 10), np.uint8)
        result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
        out_image = vis.imshow_infos(
            image, result, out_file=tmp_filename, show=False)
        assert osp.isfile(tmp_filename)
        assert image.shape == out_image.shape[:2]
        os.remove(tmp_filename)

        # test show=True
        image = np.ones((10, 10, 3), np.uint8)
        result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}

        def mock_blocking_input(self, n=1, timeout=30):
            # Stand-in for BlockingInput.__call__: records the (n, timeout)
            # it was called with as a file named "<n>_<timeout>" inside
            # tmp_dir, then returns a single fake key press of ' '.
            keypress = Mock()
            keypress.key = ' '
            out_path = osp.join(tmp_dir, '_'.join([str(n), str(timeout)]))
            with open(out_path, 'w') as f:
                f.write('test')
            return [keypress]

        with patch('matplotlib.blocking_input.BlockingInput.__call__',
                   mock_blocking_input):
            vis.imshow_infos(image, result, show=True, wait_time=5)
            # the patched input must have been called with n=1, timeout=0
            assert osp.exists(osp.join(tmp_dir, '1_0'))
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)


@patch(
    'matplotlib.blocking_input.BlockingInput.__call__',
    return_value=[Mock(key=' ')])
def test_context_manager(mock_blocking_input):
    """Check figure reuse and rebuild in ``ImshowInfosContextManager``."""
    # test show multiple images with the same figure.
    images = [
        np.random.randint(0, 255, (100, 100, 3), np.uint8) for _ in range(5)
    ]
    result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}

    with vis.ImshowInfosContextManager() as manager:
        fig_show = manager.fig_show
        fig_save = manager.fig_save
        for image in images:
            out_image = manager.put_img_infos(image, result, show=True)
            assert image.shape == out_image.shape
            assert not np.allclose(image, out_image)
            # both figures are reused across images
            assert fig_show is manager.fig_show
            assert fig_save is manager.fig_save

    # test rebuild figure if user destroyed it.
    with vis.ImshowInfosContextManager() as manager:
        fig_save = manager.fig_save
        for image in images:
            fig_show = manager.fig_show
            plt.close(manager.fig_show)
            out_image = manager.put_img_infos(image, result, show=True)
            assert image.shape == out_image.shape
            assert not np.allclose(image, out_image)
            # the closed display figure is replaced; the save figure is kept
            assert fig_show is not manager.fig_show
            assert fig_save is manager.fig_save