49 lines
1.5 KiB
Python
49 lines
1.5 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest import TestCase
|
|
|
|
import torch
|
|
|
|
from mmcls.models.utils import LayerScale
|
|
|
|
|
|
class TestLayerScale(TestCase):
|
|
|
|
def test_init(self):
|
|
with self.assertRaisesRegex(AssertionError, "'data_format' could"):
|
|
cfg = dict(
|
|
dim=10,
|
|
inplace=False,
|
|
data_format='BNC',
|
|
)
|
|
LayerScale(**cfg)
|
|
|
|
cfg = dict(dim=10)
|
|
ls = LayerScale(**cfg)
|
|
assert torch.equal(ls.weight,
|
|
torch.ones(10, requires_grad=True) * 1e-5)
|
|
|
|
def forward(self):
|
|
# Test channels_last
|
|
cfg = dict(dim=256, inplace=False, data_format='channels_last')
|
|
ls_channels_last = LayerScale(**cfg)
|
|
x = torch.randn((4, 49, 256))
|
|
out = ls_channels_last(x)
|
|
self.assertEqual(tuple(out.size()), (4, 49, 256))
|
|
assert torch.equal(x * 1e-5, out)
|
|
|
|
# Test channels_first
|
|
cfg = dict(dim=256, inplace=False, data_format='channels_first')
|
|
ls_channels_first = LayerScale(**cfg)
|
|
x = torch.randn((4, 256, 7, 7))
|
|
out = ls_channels_first(x)
|
|
self.assertEqual(tuple(out.size()), (4, 256, 7, 7))
|
|
assert torch.equal(x * 1e-5, out)
|
|
|
|
# Test inplace True
|
|
cfg = dict(dim=256, inplace=True, data_format='channels_first')
|
|
ls_channels_first = LayerScale(**cfg)
|
|
x = torch.randn((4, 256, 7, 7))
|
|
out = ls_channels_first(x)
|
|
self.assertEqual(tuple(out.size()), (4, 256, 7, 7))
|
|
self.assertIs(x, out)
|