1
0
mirror of https://github.com/open-mmlab/mmrazor.git synced 2025-06-03 15:02:54 +08:00
whcao 1e8f886523
[Feature] Feature map visualization
* WIP: vis

* WIP: add visualization

* WIP: add visualization hook

* WIP: support razor visualizer

* WIP

* WIP: wrap draw_featmap

* support feature map visualization

* add a demo image for visualization

* fix typos

* change eps to 1e-6

* add pytest for visualization

* fix vis hook

* fix arguments' name

* fix img path

* support draw inference results

* add visualization doc

* fix figure url

* move files

Co-authored-by: weihan cao <HIT-cwh>
2022-10-26 13:26:20 +08:00

129 lines
4.7 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import pytest
import torch
from mmengine.visualization import Visualizer
from mmrazor.visualization.local_visualizer import modify
class TestVisualizer(TestCase):
    """Tests for mmrazor's ``modify`` used as ``Visualizer.draw_featmap``.

    A fresh mmengine ``Visualizer`` has its ``draw_featmap`` attribute
    replaced by ``modify``. The test then checks argument validation
    (errors and warnings) and the spatial size of the returned image.
    """

    def setUp(self):
        """Setup the demo image in every test method.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> cleanUp()
        """
        self.image = np.random.randint(
            0, 256, size=(10, 10, 3)).astype('uint8')

    def test_draw_featmap(self):
        """Validate inputs and output sizes of the patched draw_featmap.

        Feature maps are random tensors, so only errors, warnings and
        output shapes (``featmap.shape[:2]``) are asserted, never pixel
        values.
        """
        visualizer = Visualizer()
        # Patch the instance attribute so every call below goes to
        # mmrazor's ``modify`` instead of mmengine's implementation.
        visualizer.draw_featmap = modify
        # 3x3 RGB overlay. When it is passed without ``resize_shape``, the
        # asserted output size follows this image (3x3), not the featmap.
        image = np.random.randint(0, 256, size=(3, 3, 3), dtype='uint8')

        # must be Tensor: a numpy array is rejected
        with pytest.raises(
                AssertionError,
                match='`featmap` should be torch.Tensor, but got '
                "<class 'numpy.ndarray'>"):
            visualizer.draw_featmap(np.ones((3, 3, 3)))

        # test tensor format: only 3-D (C, H, W) input is accepted
        with pytest.raises(
                AssertionError, match='Input dimension must be 3, but got 4'):
            visualizer.draw_featmap(torch.randn(1, 1, 3, 3))

        # test overlaid_image shape: a 4x3 featmap over a 3x3 image warns
        with pytest.warns(Warning):
            visualizer.draw_featmap(torch.randn(1, 4, 3), overlaid_image=image)

        # test resize_shape: the output takes resize_shape, with or
        # without an overlay image
        featmap = visualizer.draw_featmap(
            torch.randn(1, 4, 3), resize_shape=(6, 7))
        assert featmap.shape[:2] == (6, 7)
        featmap = visualizer.draw_featmap(
            torch.randn(1, 4, 3), overlaid_image=image, resize_shape=(6, 7))
        assert featmap.shape[:2] == (6, 7)

        # test channel_reduction parameter
        # accepted modes are 'squeeze_mean', 'select_max', 'pixel_wise_max'
        # and None; any other string raises AssertionError
        with pytest.raises(AssertionError):
            visualizer.draw_featmap(
                torch.randn(2, 3, 3), channel_reduction='xx')
        featmap = visualizer.draw_featmap(
            torch.randn(2, 3, 3), channel_reduction='squeeze_mean')
        assert featmap.shape[:2] == (3, 3)
        featmap = visualizer.draw_featmap(
            torch.randn(2, 3, 3), channel_reduction='select_max')
        assert featmap.shape[:2] == (3, 3)
        featmap = visualizer.draw_featmap(
            torch.randn(2, 3, 3), channel_reduction='pixel_wise_max')
        assert featmap.shape[:2] == (3, 3)
        featmap = visualizer.draw_featmap(
            torch.randn(2, 4, 3),
            overlaid_image=image,
            channel_reduction='pixel_wise_max')
        assert featmap.shape[:2] == (3, 3)

        # test topk parameter
        # without a reduction, topk < 1 keeps every channel, which is only
        # allowed for 1- or 3-channel input
        with pytest.raises(
                AssertionError,
                match='The input tensor channel dimension must be 1 or 3 '
                'when topk is less than 1, but the channel '
                'dimension you input is 6, you can use the '
                'channel_reduction parameter or set topk '
                'greater than 0 to solve the error'):
            visualizer.draw_featmap(
                torch.randn(6, 3, 3), channel_reduction=None, topk=0)
        # with a reduction mode set, the output stays a single 3x3 map
        # even though topk=10 is passed
        featmap = visualizer.draw_featmap(
            torch.randn(6, 3, 3), channel_reduction='select_max', topk=10)
        assert featmap.shape[:2] == (3, 3)
        featmap = visualizer.draw_featmap(
            torch.randn(1, 4, 3), channel_reduction=None, topk=-1)
        assert featmap.shape[:2] == (4, 3)
        featmap = visualizer.draw_featmap(
            torch.randn(3, 4, 3),
            overlaid_image=image,
            channel_reduction=None,
            topk=-1)
        assert featmap.shape[:2] == (3, 3)

        # topk maps are tiled as (rows, cols) = arrangement, giving an
        # output of (rows * H, cols * W) for H = W = 3
        featmap = visualizer.draw_featmap(
            torch.randn(6, 3, 3),
            channel_reduction=None,
            topk=4,
            arrangement=(2, 2))
        assert featmap.shape[:2] == (6, 6)
        featmap = visualizer.draw_featmap(
            torch.randn(6, 3, 3),
            channel_reduction=None,
            topk=4,
            arrangement=(1, 4))
        assert featmap.shape[:2] == (3, 12)
        # the grid must have room for all topk maps
        with pytest.raises(
                AssertionError,
                match='The product of row and col in the `arrangement` '
                'is less than topk, please set '
                'the `arrangement` correctly'):
            visualizer.draw_featmap(
                torch.randn(6, 3, 3),
                channel_reduction=None,
                topk=4,
                arrangement=(1, 2))

        # test gray: a 2-D (H, W) overlay image is also accepted
        featmap = visualizer.draw_featmap(
            torch.randn(6, 3, 3),
            overlaid_image=np.random.randint(
                0, 256, size=(3, 3), dtype='uint8'),
            channel_reduction=None,
            topk=4,
            arrangement=(2, 2))
        assert featmap.shape[:2] == (6, 6)