129 lines
4.7 KiB
Python
129 lines
4.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest import TestCase
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
from mmengine.visualization import Visualizer
|
|
|
|
from mmrazor.visualization.local_visualizer import modify
|
|
|
|
|
|
class TestVisualizer(TestCase):
    """Tests for ``modify`` when it is installed as ``draw_featmap``."""

    def setUp(self):
        """Create a random 10x10x3 ``uint8`` image before each test method.

        ``unittest`` invokes ``setUp()`` ahead of every test method and
        ``tearDown()`` once the test method has finished.
        """
        pixels = np.random.randint(0, 256, size=(10, 10, 3))
        self.image = pixels.astype('uint8')

    def test_draw_featmap(self):
        # Patch the instance so that each ``draw_featmap`` call below is
        # served by ``modify``.
        vis = Visualizer()
        vis.draw_featmap = modify
        overlay = np.random.randint(0, 256, size=(3, 3, 3), dtype='uint8')

        # --- input validation -------------------------------------------
        # A numpy array in place of a tensor must be rejected.
        not_tensor_msg = ('`featmap` should be torch.Tensor, but got '
                          "<class 'numpy.ndarray'>")
        with pytest.raises(AssertionError, match=not_tensor_msg):
            vis.draw_featmap(np.ones((3, 3, 3)))

        # A 4-dimensional tensor must be rejected.
        wrong_ndim_msg = 'Input dimension must be 3, but got 4'
        with pytest.raises(AssertionError, match=wrong_ndim_msg):
            vis.draw_featmap(torch.randn(1, 1, 3, 3))

        # --- overlaid_image ---------------------------------------------
        # Overlaying the 3x3 image on a (1, 4, 3) feature map is expected
        # to emit a warning.
        with pytest.warns(Warning):
            vis.draw_featmap(torch.randn(1, 4, 3), overlaid_image=overlay)

        # --- resize_shape -----------------------------------------------
        target = (6, 7)
        drawn = vis.draw_featmap(torch.randn(1, 4, 3), resize_shape=target)
        assert drawn.shape[:2] == target
        drawn = vis.draw_featmap(
            torch.randn(1, 4, 3),
            overlaid_image=overlay,
            resize_shape=target)
        assert drawn.shape[:2] == target

        # --- channel_reduction ------------------------------------------
        # An unknown reduction mode must be rejected.
        with pytest.raises(AssertionError):
            vis.draw_featmap(torch.randn(2, 3, 3), channel_reduction='xx')

        # Each of these modes yields a 3x3 output for a (2, 3, 3) input.
        for mode in ('squeeze_mean', 'select_max', 'pixel_wise_max'):
            drawn = vis.draw_featmap(
                torch.randn(2, 3, 3), channel_reduction=mode)
            assert drawn.shape[:2] == (3, 3)

        drawn = vis.draw_featmap(
            torch.randn(2, 4, 3),
            overlaid_image=overlay,
            channel_reduction='pixel_wise_max')
        assert drawn.shape[:2] == (3, 3)

        # --- topk -------------------------------------------------------
        # Without channel reduction and with topk=0, six channels are
        # rejected.
        bad_channels_msg = ('The input tensor channel dimension must be 1 or 3 '
                            'when topk is less than 1, but the channel '
                            'dimension you input is 6, you can use the '
                            'channel_reduction parameter or set topk '
                            'greater than 0 to solve the error')
        with pytest.raises(AssertionError, match=bad_channels_msg):
            vis.draw_featmap(
                torch.randn(6, 3, 3), channel_reduction=None, topk=0)

        drawn = vis.draw_featmap(
            torch.randn(6, 3, 3), channel_reduction='select_max', topk=10)
        assert drawn.shape[:2] == (3, 3)
        drawn = vis.draw_featmap(
            torch.randn(1, 4, 3), channel_reduction=None, topk=-1)
        assert drawn.shape[:2] == (4, 3)

        drawn = vis.draw_featmap(
            torch.randn(3, 4, 3),
            overlaid_image=overlay,
            channel_reduction=None,
            topk=-1)
        assert drawn.shape[:2] == (3, 3)

        # --- arrangement ------------------------------------------------
        # Keyword arguments shared by every remaining call.
        top4 = dict(channel_reduction=None, topk=4)

        # Four 3x3 maps laid out on a (rows, cols) grid.
        grid_cases = (
            ((2, 2), (6, 6)),
            ((1, 4), (3, 12)),
        )
        for grid, expected in grid_cases:
            drawn = vis.draw_featmap(
                torch.randn(6, 3, 3), arrangement=grid, **top4)
            assert drawn.shape[:2] == expected

        # A grid with fewer cells than topk must be rejected.
        small_grid_msg = ('The product of row and col in the `arrangement` '
                          'is less than topk, please set '
                          'the `arrangement` correctly')
        with pytest.raises(AssertionError, match=small_grid_msg):
            vis.draw_featmap(
                torch.randn(6, 3, 3), arrangement=(1, 2), **top4)

        # --- grayscale overlay ------------------------------------------
        gray = np.random.randint(0, 256, size=(3, 3), dtype='uint8')
        drawn = vis.draw_featmap(
            torch.randn(6, 3, 3),
            overlaid_image=gray,
            arrangement=(2, 2),
            **top4)
        assert drawn.shape[:2] == (6, 6)
|