mmcv/tests/test_cnn/test_scale.py
takuoko d510b8b174
[Feature] Support LayerScale in FFN (#2451)
* add layer scale

* add layer scale

* add layer scale

* Update mmcv/cnn/bricks/transformer.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/cnn/bricks/transformer.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* add layer scale

* move LayerScale

* add layer_scale_init_value

* add typehint

* fix for tensor with any dim

* fix layer scale rule

* fix layer scale rule

* fix test

* add docs

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
2022-12-11 17:58:19 +08:00

79 lines
2.3 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.cnn.bricks import LayerScale, Scale
def test_scale():
# test default scale
scale = Scale()
assert scale.scale.data == 1.
assert scale.scale.dtype == torch.float
x = torch.rand(1, 3, 64, 64)
output = scale(x)
assert output.shape == (1, 3, 64, 64)
# test given scale
scale = Scale(10.)
assert scale.scale.data == 10.
assert scale.scale.dtype == torch.float
x = torch.rand(1, 3, 64, 64)
output = scale(x)
assert output.shape == (1, 3, 64, 64)
def test_layer_scale():
with pytest.raises(AssertionError):
cfg = dict(
dim=10,
data_format='BNC',
)
LayerScale(**cfg)
# test init
cfg = dict(dim=10)
ls = LayerScale(**cfg)
assert torch.equal(ls.weight, torch.ones(10, requires_grad=True) * 1e-5)
# test forward
# test channels_last
cfg = dict(dim=256, inplace=False, data_format='channels_last')
ls_channels_last = LayerScale(**cfg)
x = torch.randn((4, 49, 256))
out = ls_channels_last(x)
assert tuple(out.size()) == (4, 49, 256)
assert torch.equal(x * 1e-5, out)
# test channels_last 2d
cfg = dict(dim=256, inplace=False, data_format='channels_last')
ls_channels_last = LayerScale(**cfg)
x = torch.randn((4, 7, 49, 256))
out = ls_channels_last(x)
assert tuple(out.size()) == (4, 7, 49, 256)
assert torch.equal(x * 1e-5, out)
# test channels_first
cfg = dict(dim=256, inplace=False, data_format='channels_first')
ls_channels_first = LayerScale(**cfg)
x = torch.randn((4, 256, 7, 7))
out = ls_channels_first(x)
assert tuple(out.size()) == (4, 256, 7, 7)
assert torch.equal(x * 1e-5, out)
# test channels_first 3D
cfg = dict(dim=256, inplace=False, data_format='channels_first')
ls_channels_first = LayerScale(**cfg)
x = torch.randn((4, 256, 7, 7, 7))
out = ls_channels_first(x)
assert tuple(out.size()) == (4, 256, 7, 7, 7)
assert torch.equal(x * 1e-5, out)
# test inplace True
cfg = dict(dim=256, inplace=True, data_format='channels_first')
ls_channels_first = LayerScale(**cfg)
x = torch.randn((4, 256, 7, 7))
out = ls_channels_first(x)
assert tuple(out.size()) == (4, 256, 7, 7)
assert x is out