mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Unittest] Add visualizer unittest (#27)
* add visualizer unittest * update * fix comment * fix commit * fix comment * fix comment
This commit is contained in:
parent
f0451a38f0
commit
adb2aee8c2
284
tests/test_visualizer/test_visualizer.py
Normal file
284
tests/test_visualizer/test_visualizer.py
Normal file
@ -0,0 +1,284 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmengine.visualizer import Visualizer
|
||||
|
||||
|
||||
class TesVisualizer(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""Setup the demo image in every test method.
|
||||
|
||||
TestCase calls functions in this order: setUp() -> testMethod() ->
|
||||
tearDown() -> cleanUp()
|
||||
"""
|
||||
self.image = np.random.randint(0, 256, size=(10, 10, 3))
|
||||
|
||||
def assert_img_equal(self, img, ref_img, ratio_thr=0.999):
|
||||
assert img.shape == ref_img.shape
|
||||
assert img.dtype == ref_img.dtype
|
||||
area = ref_img.shape[0] * ref_img.shape[1]
|
||||
diff = np.abs(img.astype('int32') - ref_img.astype('int32'))
|
||||
assert np.sum(diff <= 1) / float(area) > ratio_thr
|
||||
|
||||
def test_init(self):
|
||||
# test `scale` parameter
|
||||
# `scale` must be greater than 0.
|
||||
with pytest.raises(AssertionError):
|
||||
Visualizer(scale=0)
|
||||
|
||||
visualizer = Visualizer(scale=2, image=self.image)
|
||||
out_image = visualizer.get_image()
|
||||
assert (20, 20, 3) == out_image.shape
|
||||
|
||||
def test_set_image(self):
|
||||
visualizer = Visualizer()
|
||||
visualizer.set_image(self.image)
|
||||
assert self.assert_img_equal(self.image, visualizer.get_image())
|
||||
# test grayscale image
|
||||
visualizer.set_image(self.image[..., 0])
|
||||
assert self.assert_img_equal(self.image[..., 0],
|
||||
visualizer.get_image())
|
||||
|
||||
def test_get_image(self):
|
||||
visualizer = Visualizer(image=self.image)
|
||||
out_image = visualizer.get_image()
|
||||
assert self.assert_img_equal(self.image, out_image)
|
||||
|
||||
def test_draw_bboxes(self):
|
||||
visualizer = Visualizer(image=self.image)
|
||||
|
||||
# only support 4 or nx4 tensor and numpy
|
||||
visualizer.draw_bboxes(torch.tensor([1, 1, 2, 2]))
|
||||
# valid bbox
|
||||
visualizer.draw_bboxes(torch.tensor([1, 1, 1, 2]))
|
||||
bboxes = torch.tensor([[1, 1, 2, 2], [1, 2, 2, 2.5]])
|
||||
visualizer.draw_bboxes(
|
||||
bboxes, alpha=0.5, edge_color='b', line_style='-')
|
||||
bboxes = bboxes.numpy()
|
||||
visualizer.draw_bboxes(bboxes)
|
||||
|
||||
# test invalid bbox
|
||||
with pytest.raises(AssertionError):
|
||||
# x1 > x2
|
||||
visualizer.draw_bboxes(torch.tensor([5, 1, 2, 2]))
|
||||
|
||||
# test out of bounds
|
||||
with pytest.warns(
|
||||
UserWarning,
|
||||
match='Warning: The bbox is out of bounds,'
|
||||
' the drawn bbox may not be in the image'):
|
||||
visualizer.draw_bboxes(torch.tensor([1, 1, 20, 2]))
|
||||
|
||||
# test incorrect bbox format
|
||||
with pytest.raises(AssertionError):
|
||||
visualizer.draw_bboxes([1, 1, 2, 2])
|
||||
|
||||
def test_draw_texts(self):
|
||||
visualizer = Visualizer(image=self.image)
|
||||
|
||||
# only support tensor and numpy
|
||||
visualizer.draw_texts('text1', position=torch.tensor([5, 5]))
|
||||
visualizer.draw_texts(['text1', 'text2'],
|
||||
position=torch.tensor([[5, 5], [3, 3]]))
|
||||
visualizer.draw_texts('text1', position=np.array([5, 5]))
|
||||
visualizer.draw_texts(['text1', 'text2'],
|
||||
position=np.array([[5, 5], [3, 3]]))
|
||||
|
||||
# test out of bounds
|
||||
with pytest.warns(
|
||||
UserWarning,
|
||||
match='Warning: The text is out of bounds,'
|
||||
' the drawn text may not be in the image'):
|
||||
visualizer.draw_texts('text1', position=torch.tensor([15, 5]))
|
||||
|
||||
# test incorrect format
|
||||
with pytest.raises(AssertionError):
|
||||
visualizer.draw_texts('text', position=[5, 5])
|
||||
|
||||
# test length mismatch
|
||||
with pytest.raises(AssertionError):
|
||||
visualizer.draw_texts(['text1', 'text2'],
|
||||
position=torch.tensor([5, 5]))
|
||||
visualizer.draw_texts('text1', position=torch.tensor([[5, 5]]))
|
||||
visualizer.draw_texts(
|
||||
'text1', position=torch.tensor([[5, 5], [3, 3]]))
|
||||
|
||||
def test_draw_lines(self):
|
||||
visualizer = Visualizer(image=self.image)
|
||||
|
||||
# only support tensor and numpy
|
||||
visualizer.draw_lines(
|
||||
x_datas=torch.tensor([1, 5]), y_datas=torch.tensor([2, 6]))
|
||||
visualizer.draw_lines(
|
||||
x_datas=np.array([1, 5, 4]), y_datas=np.array([2, 6, 6]))
|
||||
|
||||
# test out of bounds
|
||||
with pytest.warns(
|
||||
UserWarning,
|
||||
match='Warning: The line is out of bounds,'
|
||||
' the drawn line may not be in the image'):
|
||||
visualizer.draw_lines(
|
||||
x_datas=torch.tensor([12, 5]), y_datas=torch.tensor([2, 6]))
|
||||
|
||||
# test incorrect format
|
||||
with pytest.raises(AssertionError):
|
||||
visualizer.draw_texts('text', position=[5, 5])
|
||||
|
||||
# test length mismatch
|
||||
with pytest.raises(AssertionError):
|
||||
visualizer.draw_lines(
|
||||
x_datas=torch.tensor([1, 5]), y_datas=torch.tensor([2, 6, 7]))
|
||||
|
||||
def test_draw_circles(self):
|
||||
visualizer = Visualizer(image=self.image)
|
||||
|
||||
# only support tensor and numpy
|
||||
visualizer.draw_circles(torch.tensor([1, 5]))
|
||||
visualizer.draw_circles(np.array([1, 5]))
|
||||
visualizer.draw_circles(
|
||||
torch.tensor([[1, 5], [2, 6]]), radius=torch.tensor([1, 2]))
|
||||
|
||||
# test out of bounds
|
||||
with pytest.warns(
|
||||
UserWarning,
|
||||
match='Warning: The circle is out of bounds,'
|
||||
' the drawn circle may not be in the image'):
|
||||
visualizer.draw_circles(torch.tensor([12, 5]))
|
||||
visualizer.draw_circles(torch.tensor([1, 5]), radius=10)
|
||||
|
||||
# test incorrect format
|
||||
with pytest.raises(AssertionError):
|
||||
visualizer.draw_circles([1, 5])
|
||||
|
||||
# test length mismatch
|
||||
with pytest.raises(AssertionError):
|
||||
visualizer.draw_circles(
|
||||
torch.tensor([[1, 5]]), radius=torch.tensor([1, 2]))
|
||||
|
||||
def test_draw_polygons(self):
|
||||
visualizer = Visualizer(image=self.image)
|
||||
# shape Nx2 or list[Nx2]
|
||||
visualizer.draw_polygons(torch.tensor([[1, 1], [2, 2], [3, 4]]))
|
||||
visualizer.draw_polygons(np.array([[1, 1], [2, 2], [3, 4]]))
|
||||
visualizer.draw_polygons([
|
||||
np.array([[1, 1], [2, 2], [3, 4]]),
|
||||
torch.tensor([[1, 1], [2, 2], [3, 4]])
|
||||
])
|
||||
|
||||
# test out of bounds
|
||||
with pytest.warns(
|
||||
UserWarning,
|
||||
match='Warning: The polygon is out of bounds,'
|
||||
' the drawn polygon may not be in the image'):
|
||||
visualizer.draw_polygons(torch.tensor([[1, 1], [2, 2], [16, 4]]))
|
||||
|
||||
def test_draw_binary_masks(self):
|
||||
binary_mask = np.random.randint(0, 2, size=(10, 10)).astype(np.bool)
|
||||
visualizer = Visualizer(image=self.image)
|
||||
visualizer.draw_binary_masks(binary_mask)
|
||||
visualizer.draw_binary_masks(torch.from_numpy(binary_mask))
|
||||
|
||||
# test the error that the size of mask and image are different.
|
||||
with pytest.raises(AssertionError):
|
||||
binary_mask = np.random.randint(0, 2, size=(8, 10)).astype(np.bool)
|
||||
visualizer.draw_binary_masks(binary_mask)
|
||||
|
||||
# test non binary mask error
|
||||
binary_mask = np.random.randint(0, 2, size=(10, 10, 3)).astype(np.bool)
|
||||
with pytest.raises(AssertionError):
|
||||
visualizer.draw_binary_masks(binary_mask)
|
||||
|
||||
# test non bool error
|
||||
binary_mask = np.random.randint(0, 2, size=(10, 10))
|
||||
with pytest.raises(AssertionError):
|
||||
visualizer.draw_binary_masks(binary_mask)
|
||||
|
||||
def test_draw_featmap(self):
|
||||
visualizer = Visualizer()
|
||||
|
||||
# test tensor format
|
||||
with pytest.raises(AssertionError, match='Input dimension must be 3'):
|
||||
visualizer.draw_featmap(torch.randn(1, 1, 3, 3))
|
||||
|
||||
# test mode parameter
|
||||
# mode only supports 'mean' and 'max'
|
||||
with pytest.raises(AssertionError):
|
||||
visualizer.draw_featmap(torch.randn(2, 3, 3), mode='min')
|
||||
|
||||
# test topk parameter
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match='The input tensor channel dimension must be 1 or 3 '
|
||||
'when topk is less than 1, but the channel '
|
||||
'dimension you input is 6, you can use the '
|
||||
'mode parameter or set topk greater than 0 to solve '
|
||||
'the error'):
|
||||
visualizer.draw_featmap(torch.randn(6, 3, 3), mode=None, topk=0)
|
||||
|
||||
visualizer.draw_featmap(torch.randn(6, 3, 3), mode='mean')
|
||||
visualizer.draw_featmap(torch.randn(1, 3, 3), mode='mean')
|
||||
visualizer.draw_featmap(torch.randn(6, 3, 3), mode='max')
|
||||
visualizer.draw_featmap(torch.randn(6, 3, 3), mode='max', topk=10)
|
||||
visualizer.draw_featmap(torch.randn(1, 3, 3), mode=None, topk=-1)
|
||||
visualizer.draw_featmap(torch.randn(3, 3, 3), mode=None, topk=-1)
|
||||
visualizer.draw_featmap(torch.randn(6, 3, 3), mode=None, topk=4)
|
||||
visualizer.draw_featmap(torch.randn(6, 3, 3), mode=None, topk=8)
|
||||
|
||||
def test_chain_call(self):
|
||||
visualizer = Visualizer(image=self.image)
|
||||
binary_mask = np.random.randint(0, 2, size=(10, 10)).astype(np.bool)
|
||||
visualizer.draw_bboxes(torch.tensor([1, 1, 2, 2])). \
|
||||
draw_texts('test', torch.tensor([5, 5])). \
|
||||
draw_lines(x_datas=torch.tensor([1, 5]),
|
||||
y_datas=torch.tensor([2, 6])). \
|
||||
draw_circles(torch.tensor([1, 5])). \
|
||||
draw_polygons(torch.tensor([[1, 1], [2, 2], [3, 4]])). \
|
||||
draw_binary_masks(binary_mask)
|
||||
|
||||
def test_register_task(self):
|
||||
|
||||
class DetVisualizer(Visualizer):
|
||||
|
||||
@Visualizer.register_task('instances')
|
||||
def draw_instance(self, instances, data_type):
|
||||
pass
|
||||
|
||||
assert len(Visualizer.task_dict) == 1
|
||||
assert 'instances' in Visualizer.task_dict
|
||||
|
||||
# test registration of the same names.
|
||||
with pytest.raises(
|
||||
KeyError,
|
||||
match=('"instances" is already registered in task_dict, '
|
||||
'add "force=True" if you want to override it')):
|
||||
|
||||
class DetVisualizer1(Visualizer):
|
||||
|
||||
@Visualizer.register_task('instances')
|
||||
def draw_instance1(self, instances, data_type):
|
||||
pass
|
||||
|
||||
@Visualizer.register_task('instances')
|
||||
def draw_instance2(self, instances, data_type):
|
||||
pass
|
||||
|
||||
class DetVisualizer2(Visualizer):
|
||||
|
||||
@Visualizer.register_task('instances')
|
||||
def draw_instance1(self, instances, data_type):
|
||||
pass
|
||||
|
||||
@Visualizer.register_task('instances', force=True)
|
||||
def draw_instance2(self, instances, data_type):
|
||||
pass
|
||||
|
||||
det_visualizer = DetVisualizer2()
|
||||
assert len(det_visualizer.task_dict) == 1
|
||||
assert 'instances' in det_visualizer.task_dict
|
||||
assert det_visualizer.task_dict[
|
||||
'instances'].__name__ == 'draw_instance2'
|
141
tests/test_visualizer/test_writer.py
Normal file
141
tests/test_visualizer/test_writer.py
Normal file
@ -0,0 +1,141 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import random
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmengine.data import BaseDataElement, BaseDataSample
|
||||
from mmengine.visualizer import (VISUALIZERS, LocalWriter, TensorboardWriter,
|
||||
WandbWriter)
|
||||
|
||||
|
||||
def get_demo_datasample():
|
||||
metainfo = dict(
|
||||
img_id=random.randint(0, 100),
|
||||
img_shape=(random.randint(400, 600), random.randint(400, 600)))
|
||||
gt_instances = BaseDataElement(
|
||||
data=dict(bboxes=torch.rand((5, 4)), labels=torch.rand((5, ))))
|
||||
pred_instances = BaseDataElement(
|
||||
data=dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5, ))))
|
||||
data = dict(gt_instances=gt_instances, pred_instances=pred_instances)
|
||||
data_sample = BaseDataSample(data=data, metainfo=metainfo)
|
||||
return data_sample
|
||||
|
||||
|
||||
class TestLocalWriter:
|
||||
|
||||
def test_add_image(self):
|
||||
image = np.random.randint(0, 256, size=(10, 10, 3))
|
||||
data_sample = get_demo_datasample()
|
||||
|
||||
local_writer = LocalWriter(visuailzer=dict(type='Visualizer'))
|
||||
local_writer.add_image('img', image)
|
||||
local_writer.add_image('img', image, data_sample)
|
||||
|
||||
bboxes = np.array([[1, 1, 2, 2], [1, 1.5, 1, 2.5]])
|
||||
local_writer.visualizer.draw_bboxes(bboxes)
|
||||
local_writer.add_image('img', local_writer.visualizer.get_image())
|
||||
|
||||
visuailzer = VISUALIZERS.build(dict(type='Visualizer'))
|
||||
local_writer = LocalWriter(visuailzer=visuailzer)
|
||||
local_writer.add_image('img', image)
|
||||
local_writer.add_image('img', image, data_sample)
|
||||
|
||||
# test `visuailzer` parameter
|
||||
# `visuailzer` parameter which must be either Visualizer instance
|
||||
# or valid dictionary.
|
||||
with pytest.raises(AssertionError):
|
||||
|
||||
class A:
|
||||
pass
|
||||
|
||||
LocalWriter(visuailzer=A())
|
||||
with pytest.raises(AssertionError):
|
||||
LocalWriter(visuailzer=dict(a='Visualizer'))
|
||||
|
||||
# test not visuailzer
|
||||
# The visuailzer parameter must be set when
|
||||
# the local_writer object is instantiated and
|
||||
# the `add_image` method is called.
|
||||
with pytest.raises(AssertionError):
|
||||
local_writer = LocalWriter()
|
||||
local_writer.add_image('img', image)
|
||||
|
||||
def test_add_scaler(self):
|
||||
local_writer = LocalWriter()
|
||||
local_writer.add_scaler('map', 0.9)
|
||||
|
||||
def test_add_hyperparams(self):
|
||||
local_writer = LocalWriter()
|
||||
local_writer.add_hyperparams('hyper', dict(lr=0.01))
|
||||
|
||||
|
||||
class TestWandbWriter:
|
||||
sys.modules['wandb'] = MagicMock()
|
||||
|
||||
def test_add_image(self):
|
||||
image = np.random.randint(0, 256, size=(10, 10, 3))
|
||||
data_sample = get_demo_datasample()
|
||||
|
||||
wandb_writer = WandbWriter()
|
||||
assert not wandb_writer.visualizer
|
||||
wandb_writer.add_image('img', image, data_sample)
|
||||
|
||||
wandb_writer = WandbWriter(visuailzer=dict(type='Visualizer'))
|
||||
assert wandb_writer.visualizer
|
||||
wandb_writer.add_image('img', image)
|
||||
wandb_writer.add_image('img', image, data_sample)
|
||||
|
||||
wandb_writer.visualizer.set_image(image)
|
||||
wandb_writer.add_image('img', wandb_writer.visualizer.get_image())
|
||||
|
||||
# TODO test file exist
|
||||
|
||||
def test_add_scaler(self):
|
||||
wandb_writer = WandbWriter()
|
||||
wandb_writer.add_scaler('map', 0.9)
|
||||
|
||||
def test_add_hyperparams(self):
|
||||
wandb_writer = WandbWriter()
|
||||
wandb_writer.add_hyperparams('hyper', dict(lr=0.01))
|
||||
|
||||
|
||||
class TestTensorboardWriter:
|
||||
sys.modules['torch.utils.tensorboard.SummaryWriter'] = MagicMock()
|
||||
|
||||
def test_add_image(self):
|
||||
image = np.random.randint(0, 256, size=(10, 10, 3))
|
||||
data_sample = get_demo_datasample()
|
||||
|
||||
tensorboard_writer = TensorboardWriter()
|
||||
assert not tensorboard_writer.visualizer
|
||||
tensorboard_writer.add_image('img', image, data_sample)
|
||||
|
||||
tensorboard_writer = TensorboardWriter(
|
||||
visuailzer=dict(type='Visualizer'))
|
||||
assert tensorboard_writer.visualizer
|
||||
tensorboard_writer.add_image('img', image)
|
||||
tensorboard_writer.add_image('img', image, data_sample)
|
||||
|
||||
tensorboard_writer.visualizer.set_image(image)
|
||||
tensorboard_writer.add_image('img',
|
||||
tensorboard_writer.visualizer.get_image())
|
||||
|
||||
# test no visualizer
|
||||
# The visuailzer parameter must be set when
|
||||
# the tensorboard_writer object is instantiated and
|
||||
# the `add_image` method is called.
|
||||
with pytest.raises(AssertionError):
|
||||
tensorboard_writer = TensorboardWriter()
|
||||
tensorboard_writer.add_image('img', image)
|
||||
|
||||
def test_add_scaler(self):
|
||||
tensorboard_writer = TensorboardWriter()
|
||||
tensorboard_writer.add_scaler('map', 0.9)
|
||||
|
||||
def test_add_hyperparams(self):
|
||||
tensorboard_writer = TensorboardWriter()
|
||||
tensorboard_writer.add_hyperparams('hyper', dict(lr=0.01))
|
Loading…
x
Reference in New Issue
Block a user