# Mirror of https://github.com/open-mmlab/mmcv.git (79 lines, 2.3 KiB, Python).
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
|
|
from mmcv.cnn.bricks import LayerScale, Scale
|
|
|
|
|
|
def test_scale():
    """Check that ``Scale`` holds a float32 scalar and multiplies inputs by it.

    Both the default construction (expected factor 1.0) and an explicit
    factor (10.0) are exercised.
    """
    # (constructor positional args, expected scale value)
    cases = (((), 1.), ((10., ), 10.))
    for args, expected in cases:
        scale = Scale(*args)
        assert scale.scale.data == expected
        assert scale.scale.dtype == torch.float

        x = torch.rand(1, 3, 64, 64)
        output = scale(x)
        assert output.shape == (1, 3, 64, 64)
        # Checking only the shape would let a no-op or wrong-factor forward
        # pass; also verify the values are scaled by the expected factor.
        assert torch.allclose(output, x * expected)
|
|
|
|
|
|
def test_layer_scale():
    """Check ``LayerScale`` argument validation, initialization and forward.

    Covers:
      * rejection of an unsupported ``data_format``;
      * the default weight (a ``dim``-vector filled with 1e-5);
      * out-of-place forward for channels_last / channels_first inputs of
        several ranks;
      * in-place forward, which must return the (scaled) input tensor itself.
    """
    # An unsupported data_format must raise at construction time. Only the
    # constructor call sits inside ``pytest.raises`` so that an unrelated
    # AssertionError elsewhere cannot make this check pass by accident.
    cfg = dict(dim=10, data_format='BNC')
    with pytest.raises(AssertionError):
        LayerScale(**cfg)

    # test init: with only ``dim`` given, the weight is ones(dim) * 1e-5.
    ls = LayerScale(dim=10)
    assert torch.equal(ls.weight, torch.ones(10) * 1e-5)

    # test forward (out-of-place): (data_format, input shape) pairs.
    # channels_last inputs carry the 256 channels on the last axis,
    # channels_first inputs on axis 1.
    forward_cases = [
        ('channels_last', (4, 49, 256)),
        ('channels_last', (4, 7, 49, 256)),
        ('channels_first', (4, 256, 7, 7)),
        ('channels_first', (4, 256, 7, 7, 7)),
    ]
    for data_format, shape in forward_cases:
        layer = LayerScale(dim=256, inplace=False, data_format=data_format)
        x = torch.randn(shape)
        out = layer(x)
        assert tuple(out.size()) == shape
        assert torch.equal(x * 1e-5, out)

    # test inplace True: the input tensor itself is scaled and returned.
    layer = LayerScale(dim=256, inplace=True, data_format='channels_first')
    x = torch.randn((4, 256, 7, 7))
    # Compute the reference before the forward pass overwrites ``x``.
    expected = x * 1e-5
    out = layer(x)
    assert tuple(out.size()) == (4, 256, 7, 7)
    assert x is out
    assert torch.equal(expected, out)
|