mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
* add visualizer unittest * update * fix comment * fix commit * fix comment * fix comment
285 lines
11 KiB
Python
285 lines
11 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest import TestCase
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
from mmengine.visualizer import Visualizer
|
|
|
|
|
|
class TesVisualizer(TestCase):
    """Unit tests for :class:`mmengine.visualizer.Visualizer`.

    NOTE(review): the class name looks like a typo for ``TestVisualizer``;
    it is kept unchanged so existing test selectors keep working. pytest
    collects ``unittest.TestCase`` subclasses regardless of their name.
    """

    def setUp(self):
        """Set up the demo image in every test method.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> cleanUp()
        """
        # A random 10x10 3-channel image with values in [0, 255].
        # ``np.random.randint`` returns the platform default integer dtype
        # (e.g. int64); cast to uint8 so the demo image has the usual 8-bit
        # image dtype, since ``assert_img_equal`` compares dtypes
        # (assumed to match what ``get_image()`` returns -- confirm).
        self.image = np.random.randint(
            0, 256, size=(10, 10, 3)).astype('uint8')

    def assert_img_equal(self, img, ref_img, ratio_thr=0.999):
        """Assert that ``img`` is (almost) identical to ``ref_img``.

        The images must have the same shape and dtype, and more than
        ``ratio_thr`` of their elements must differ by at most 1.

        Args:
            img (np.ndarray): The image under test.
            ref_img (np.ndarray): The reference image.
            ratio_thr (float): Minimum fraction of elements whose absolute
                difference is <= 1. Defaults to 0.999.

        Raises:
            AssertionError: If any of the conditions above does not hold.

        Note:
            This helper returns ``None``; call it directly rather than
            wrapping it in ``assert``.
        """
        assert img.shape == ref_img.shape
        assert img.dtype == ref_img.dtype
        # Cast to a signed type first so the subtraction cannot wrap around
        # for unsigned (e.g. uint8) inputs.
        diff = np.abs(img.astype('int32') - ref_img.astype('int32'))
        # Normalise by the total number of compared elements (H * W * C).
        # Dividing by H * W alone would let a 3-channel image pass with
        # roughly two thirds of its values mismatching.
        assert np.sum(diff <= 1) / float(diff.size) > ratio_thr

    def test_init(self):
        # test `scale` parameter
        # `scale` must be greater than 0.
        with pytest.raises(AssertionError):
            Visualizer(scale=0)

        # With scale=2 the 10x10 input is rendered as a 20x20 image.
        visualizer = Visualizer(scale=2, image=self.image)
        out_image = visualizer.get_image()
        assert (20, 20, 3) == out_image.shape

    def test_set_image(self):
        visualizer = Visualizer()
        visualizer.set_image(self.image)
        # ``assert_img_equal`` raises on mismatch and returns None, so it
        # must not be wrapped in ``assert`` (``assert None`` always fails).
        self.assert_img_equal(self.image, visualizer.get_image())

        # test grayscale image
        # NOTE(review): if get_image() always returns a 3-channel canvas,
        # this comparison fails on shape -- confirm the expected output for
        # single-channel inputs.
        visualizer.set_image(self.image[..., 0])
        self.assert_img_equal(self.image[..., 0], visualizer.get_image())

    def test_get_image(self):
        visualizer = Visualizer(image=self.image)
        out_image = visualizer.get_image()
        self.assert_img_equal(self.image, out_image)

    def test_draw_bboxes(self):
        visualizer = Visualizer(image=self.image)

        # only support a length-4 or Nx4 tensor / ndarray
        visualizer.draw_bboxes(torch.tensor([1, 1, 2, 2]))
        # valid bbox: a degenerate box with x1 == x2 is still accepted
        visualizer.draw_bboxes(torch.tensor([1, 1, 1, 2]))
        bboxes = torch.tensor([[1, 1, 2, 2], [1, 2, 2, 2.5]])
        visualizer.draw_bboxes(
            bboxes, alpha=0.5, edge_color='b', line_style='-')
        bboxes = bboxes.numpy()
        visualizer.draw_bboxes(bboxes)

        # test invalid bbox
        with pytest.raises(AssertionError):
            # x1 > x2
            visualizer.draw_bboxes(torch.tensor([5, 1, 2, 2]))

        # test out of bounds
        with pytest.warns(
                UserWarning,
                match='Warning: The bbox is out of bounds,'
                ' the drawn bbox may not be in the image'):
            visualizer.draw_bboxes(torch.tensor([1, 1, 20, 2]))

        # test incorrect bbox format (plain Python list)
        with pytest.raises(AssertionError):
            visualizer.draw_bboxes([1, 1, 2, 2])

    def test_draw_texts(self):
        visualizer = Visualizer(image=self.image)

        # only support tensor and numpy positions
        visualizer.draw_texts('text1', position=torch.tensor([5, 5]))
        visualizer.draw_texts(['text1', 'text2'],
                              position=torch.tensor([[5, 5], [3, 3]]))
        visualizer.draw_texts('text1', position=np.array([5, 5]))
        visualizer.draw_texts(['text1', 'text2'],
                              position=np.array([[5, 5], [3, 3]]))
        # a single text with a single 1x2 position is a matching pair
        visualizer.draw_texts('text1', position=torch.tensor([[5, 5]]))

        # test out of bounds
        with pytest.warns(
                UserWarning,
                match='Warning: The text is out of bounds,'
                ' the drawn text may not be in the image'):
            visualizer.draw_texts('text1', position=torch.tensor([15, 5]))

        # test incorrect format (plain Python list)
        with pytest.raises(AssertionError):
            visualizer.draw_texts('text', position=[5, 5])

        # test length mismatch. Each case gets its own ``raises`` block:
        # inside a single block, only the first raising call would run.
        # two texts, one position
        with pytest.raises(AssertionError):
            visualizer.draw_texts(['text1', 'text2'],
                                  position=torch.tensor([5, 5]))
        # one text, two positions
        with pytest.raises(AssertionError):
            visualizer.draw_texts(
                'text1', position=torch.tensor([[5, 5], [3, 3]]))

    def test_draw_lines(self):
        visualizer = Visualizer(image=self.image)

        # only support tensor and numpy
        visualizer.draw_lines(
            x_datas=torch.tensor([1, 5]), y_datas=torch.tensor([2, 6]))
        visualizer.draw_lines(
            x_datas=np.array([1, 5, 4]), y_datas=np.array([2, 6, 6]))

        # test out of bounds
        with pytest.warns(
                UserWarning,
                match='Warning: The line is out of bounds,'
                ' the drawn line may not be in the image'):
            visualizer.draw_lines(
                x_datas=torch.tensor([12, 5]), y_datas=torch.tensor([2, 6]))

        # test incorrect format: plain Python lists are rejected
        # (previously this case mistakenly exercised draw_texts).
        with pytest.raises(AssertionError):
            visualizer.draw_lines(x_datas=[1, 5], y_datas=[2, 6])

        # test length mismatch
        with pytest.raises(AssertionError):
            visualizer.draw_lines(
                x_datas=torch.tensor([1, 5]), y_datas=torch.tensor([2, 6, 7]))

    def test_draw_circles(self):
        visualizer = Visualizer(image=self.image)

        # only support tensor and numpy
        visualizer.draw_circles(torch.tensor([1, 5]))
        visualizer.draw_circles(np.array([1, 5]))
        visualizer.draw_circles(
            torch.tensor([[1, 5], [2, 6]]), radius=torch.tensor([1, 2]))

        # test out of bounds (center outside, or radius reaching outside)
        with pytest.warns(
                UserWarning,
                match='Warning: The circle is out of bounds,'
                ' the drawn circle may not be in the image'):
            visualizer.draw_circles(torch.tensor([12, 5]))
            visualizer.draw_circles(torch.tensor([1, 5]), radius=10)

        # test incorrect format (plain Python list)
        with pytest.raises(AssertionError):
            visualizer.draw_circles([1, 5])

        # test length mismatch: one center, two radii
        with pytest.raises(AssertionError):
            visualizer.draw_circles(
                torch.tensor([[1, 5]]), radius=torch.tensor([1, 2]))

    def test_draw_polygons(self):
        visualizer = Visualizer(image=self.image)
        # shape Nx2 or list[Nx2]
        visualizer.draw_polygons(torch.tensor([[1, 1], [2, 2], [3, 4]]))
        visualizer.draw_polygons(np.array([[1, 1], [2, 2], [3, 4]]))
        visualizer.draw_polygons([
            np.array([[1, 1], [2, 2], [3, 4]]),
            torch.tensor([[1, 1], [2, 2], [3, 4]])
        ])

        # test out of bounds
        with pytest.warns(
                UserWarning,
                match='Warning: The polygon is out of bounds,'
                ' the drawn polygon may not be in the image'):
            visualizer.draw_polygons(torch.tensor([[1, 1], [2, 2], [16, 4]]))

    def test_draw_binary_masks(self):
        # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the
        # portable spelling for a boolean dtype.
        binary_mask = np.random.randint(0, 2, size=(10, 10)).astype(bool)
        visualizer = Visualizer(image=self.image)
        visualizer.draw_binary_masks(binary_mask)
        visualizer.draw_binary_masks(torch.from_numpy(binary_mask))

        # test the error that the size of mask and image are different.
        binary_mask = np.random.randint(0, 2, size=(8, 10)).astype(bool)
        with pytest.raises(AssertionError):
            visualizer.draw_binary_masks(binary_mask)

        # test non binary mask error (extra channel dimension)
        binary_mask = np.random.randint(0, 2, size=(10, 10, 3)).astype(bool)
        with pytest.raises(AssertionError):
            visualizer.draw_binary_masks(binary_mask)

        # test non bool error (integer mask with the right shape)
        binary_mask = np.random.randint(0, 2, size=(10, 10))
        with pytest.raises(AssertionError):
            visualizer.draw_binary_masks(binary_mask)

    def test_draw_featmap(self):
        visualizer = Visualizer()

        # test tensor format: a 4-D input is rejected
        with pytest.raises(AssertionError, match='Input dimension must be 3'):
            visualizer.draw_featmap(torch.randn(1, 1, 3, 3))

        # test mode parameter
        # mode only supports 'mean' and 'max'
        with pytest.raises(AssertionError):
            visualizer.draw_featmap(torch.randn(2, 3, 3), mode='min')

        # test topk parameter
        with pytest.raises(
                AssertionError,
                match='The input tensor channel dimension must be 1 or 3 '
                'when topk is less than 1, but the channel '
                'dimension you input is 6, you can use the '
                'mode parameter or set topk greater than 0 to solve '
                'the error'):
            visualizer.draw_featmap(torch.randn(6, 3, 3), mode=None, topk=0)

        visualizer.draw_featmap(torch.randn(6, 3, 3), mode='mean')
        visualizer.draw_featmap(torch.randn(1, 3, 3), mode='mean')
        visualizer.draw_featmap(torch.randn(6, 3, 3), mode='max')
        visualizer.draw_featmap(torch.randn(6, 3, 3), mode='max', topk=10)
        visualizer.draw_featmap(torch.randn(1, 3, 3), mode=None, topk=-1)
        visualizer.draw_featmap(torch.randn(3, 3, 3), mode=None, topk=-1)
        visualizer.draw_featmap(torch.randn(6, 3, 3), mode=None, topk=4)
        visualizer.draw_featmap(torch.randn(6, 3, 3), mode=None, topk=8)

    def test_chain_call(self):
        visualizer = Visualizer(image=self.image)
        binary_mask = np.random.randint(0, 2, size=(10, 10)).astype(bool)
        # The draw_* methods are expected to return the visualizer itself so
        # that calls can be chained.
        (visualizer.draw_bboxes(torch.tensor([1, 1, 2, 2]))
         .draw_texts('test', torch.tensor([5, 5]))
         .draw_lines(x_datas=torch.tensor([1, 5]),
                     y_datas=torch.tensor([2, 6]))
         .draw_circles(torch.tensor([1, 5]))
         .draw_polygons(torch.tensor([[1, 1], [2, 2], [3, 4]]))
         .draw_binary_masks(binary_mask))

    def test_register_task(self):
        # NOTE(review): the decorators below run at class-body execution
        # time and are invoked on ``Visualizer`` itself, so they appear to
        # write into ``Visualizer.task_dict``. If that dict is shared
        # class-level state, the exact-count assertions here depend on test
        # order -- confirm against Visualizer.register_task.

        class DetVisualizer(Visualizer):

            @Visualizer.register_task('instances')
            def draw_instance(self, instances, data_type):
                pass

        assert len(Visualizer.task_dict) == 1
        assert 'instances' in Visualizer.task_dict

        # test registration of the same names.
        with pytest.raises(
                KeyError,
                match=('"instances" is already registered in task_dict, '
                       'add "force=True" if you want to override it')):

            class DetVisualizer1(Visualizer):

                @Visualizer.register_task('instances')
                def draw_instance1(self, instances, data_type):
                    pass

                @Visualizer.register_task('instances')
                def draw_instance2(self, instances, data_type):
                    pass

        # ``force=True`` overrides an existing registration.
        class DetVisualizer2(Visualizer):

            @Visualizer.register_task('instances')
            def draw_instance1(self, instances, data_type):
                pass

            @Visualizer.register_task('instances', force=True)
            def draw_instance2(self, instances, data_type):
                pass

        det_visualizer = DetVisualizer2()
        assert len(det_visualizer.task_dict) == 1
        assert 'instances' in det_visualizer.task_dict
        assert det_visualizer.task_dict[
            'instances'].__name__ == 'draw_instance2'
|