# mmclassification/tests/test_utils/test_visualization.py
#
# 107 lines
# 3.7 KiB
# Python

# Copyright (c) Open-MMLab. All rights reserved.
import os
import os.path as osp
import shutil
import tempfile
from unittest.mock import Mock, patch
import matplotlib.pyplot as plt
import mmcv
import numpy as np
import pytest
from mmcls.core import visualization as vis
def test_color():
    """``color_val_matplotlib`` maps colour specs to matplotlib RGB floats.

    Accepted specs (an ``mmcv.Color``, a colour name, a tuple, a single int
    and an int ndarray) come back as an ``(r, g, b)`` tuple scaled by 1/255;
    the unsupported or out-of-range specs below must raise.
    """
    valid_cases = [
        (mmcv.Color.blue, (0., 0., 1.)),
        ('green', (0., 1., 0.)),
        # tuples are read in BGR order, hence the reversed channels
        ((1, 2, 3), (3 / 255, 2 / 255, 1 / 255)),
        # a single int is used for all three channels
        (100, (100 / 255, 100 / 255, 100 / 255)),
        (np.zeros(3, dtype=int), (0., 0., 0.)),
    ]
    for color, expected in valid_cases:
        assert vis.color_val_matplotlib(color) == expected

    invalid_cases = [
        # a list is refused with TypeError even for white [255, 255, 255];
        # presumably it is the list type being refused, since only a tuple
        # form is accepted above
        ([255, 255, 255], TypeError),
        # floats are not a supported colour spec
        (1.0, TypeError),
        # an out-of-range channel value (500) trips an assertion
        ((0, 0, 500), AssertionError),
    ]
    for color, error in invalid_cases:
        with pytest.raises(error):
            vis.color_val_matplotlib(color)
def test_imshow_infos():
    """``imshow_infos`` draws prediction infos and writes / shows the result.

    Covers a colour image saved to disk, a grayscale image saved to disk, and
    ``show=True`` with matplotlib's ``BlockingInput.__call__`` patched out.
    """
    # Use a fresh per-run directory. The previous fixed path under the system
    # temp dir was shared by concurrent runs and leaked whenever an assertion
    # failed. A stale '1_0' marker left behind by an earlier run could also
    # make the show=True check pass without the patched input being called.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_filename = osp.join(tmp_dir, 'image.jpg')

        # colour image: written to disk, same shape, pixels changed
        image = np.ones((10, 10, 3), np.uint8)
        result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
        out_image = vis.imshow_infos(
            image, result, out_file=tmp_filename, show=False)
        assert osp.isfile(tmp_filename)
        assert image.shape == out_image.shape
        assert not np.allclose(image, out_image)
        # remove it so the next isfile() check proves a new write happened
        os.remove(tmp_filename)

        # test grayscale images
        image = np.ones((10, 10), np.uint8)
        result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
        out_image = vis.imshow_infos(
            image, result, out_file=tmp_filename, show=False)
        assert osp.isfile(tmp_filename)
        assert image.shape == out_image.shape[:2]
        os.remove(tmp_filename)

        # test show=True
        image = np.ones((10, 10, 3), np.uint8)
        result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}

        def mock_blocking_input(self, n=1, timeout=30):
            # Record the (n, timeout) this call received as a marker file
            # named '<n>_<timeout>', then report a single space-key press.
            keypress = Mock()
            keypress.key = ' '
            out_path = osp.join(tmp_dir, '_'.join([str(n), str(timeout)]))
            with open(out_path, 'w') as f:
                f.write('test')
            return [keypress]

        with patch('matplotlib.blocking_input.BlockingInput.__call__',
                   mock_blocking_input):
            vis.imshow_infos(image, result, show=True, wait_time=5)
            # NOTE(review): wait_time=5 yet the marker expects timeout=0;
            # presumably imshow_infos does not forward wait_time as the
            # BlockingInput timeout. Confirm against the implementation.
            assert osp.exists(osp.join(tmp_dir, '1_0'))
@patch(
    'matplotlib.blocking_input.BlockingInput.__call__',
    return_value=[Mock(key=' ')])
def test_context_manager(mock_blocking_input):
    """``ImshowInfosContextManager`` keeps reusing its figures across images.

    ``BlockingInput.__call__`` is patched to return one fake key press whose
    ``key`` is a space. Presumably this lets the ``show=True`` calls below
    return without waiting for real user input.
    """
    imgs = [
        np.random.randint(0, 255, (100, 100, 3), np.uint8) for _ in range(5)
    ]
    pred_info = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}

    def check_drawn(src, drawn):
        # the rendered output keeps the input shape but its pixels differ
        assert src.shape == drawn.shape
        assert not np.allclose(src, drawn)

    # Several images shown through one manager share the same two figures.
    with vis.ImshowInfosContextManager() as ctx:
        show_fig, save_fig = ctx.fig_show, ctx.fig_save
        for img in imgs:
            check_drawn(img, ctx.put_img_infos(img, pred_info, show=True))
            assert ctx.fig_show is show_fig
            assert ctx.fig_save is save_fig

    # If the display figure is closed behind the manager's back, the manager
    # must build a new one, while the save figure stays the same object.
    with vis.ImshowInfosContextManager() as ctx:
        save_fig = ctx.fig_save
        for img in imgs:
            closed_fig = ctx.fig_show
            plt.close(closed_fig)
            check_drawn(img, ctx.put_img_infos(img, pred_info, show=True))
            assert ctx.fig_show is not closed_fig
            assert ctx.fig_save is save_fig