Ma Zerun 076ee10cac
[Feature] Add swin-transformer model. (#271)
* Add Swin Transformer archs S, B and L.

* Add SwinTransformer configs

* Add training config files for Swin.

* Align init method with original code

* Use nn.Unfold to merge patches (see the patch-merging sketch after this log)

* Change all ConfigDict to dict

* Add init_cfg for all subclasses of BaseModule.

* Use mmcv version init function

* Add Swin README

* Use safer cfg copy method

* Improve docstring and variable name.

* Fix some differences in RandAugment

Fix BGR bug, align scheduler config.

Fix label smoothing parameter difference.

* Fix missing DropPath in attention

* Fix bug of relative position table if window width is not equal to
height (see the relative-position sketch after this log).

* Make `PatchMerging` more general, supporting kernel size, stride, padding
and dilation.

* Rename `residual` to `identity` in attention and FFN.

* Add `auto_pad` option to automatically pad the feature map

* Improve docstring.

* Fix bug in ShiftWMSA padding.

* Remove unused `key` and `value` in ShiftWMSA

* Move `PatchMerging` into utils and use common `PatchEmbed`.

* Use latest `LinearClsHead`, train augments and label smoothing settings,
and remove the original `SwinLinearClsHead`.

* Mark some configs as "Evaluation Only".

* Remove useless comment in config

* 1. Move ShiftWindowMSA and WindowMSA to `utils/attention.py`
2. Add docstrings to each module.
3. Fix some variable names.
4. Other small improvements.

* Add unit tests of swin-transformer and patchmerging.

* Fix some bugs in unit tests.

* Fix bug of rel_position_index if window is not square.

* Make WindowMSA implicit, and add unit tests.

* Add metafile.yml, update readme and model_zoo.
2021-07-01 09:30:42 +08:00
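
For reference, a minimal sketch of the nn.Unfold-based patch merging described
in the log above. `SimplePatchMerging` is a hypothetical, simplified stand-in
for the actual `PatchMerging` (square kernels only, and the `auto_pad` branch
assumes the default stride == kernel_size):

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimplePatchMerging(nn.Module):
    """Gather each k x k neighbourhood with nn.Unfold, then project the
    C * k * k sample down to expansion_ratio * C channels."""

    def __init__(self, input_resolution, in_channels, expansion_ratio,
                 kernel_size=2, stride=None, padding=0, dilation=1,
                 auto_pad=False):
        super().__init__()
        self.input_resolution = input_resolution
        self.kernel_size = kernel_size
        self.auto_pad = auto_pad
        stride = stride if stride is not None else kernel_size
        # nn.Unfold positional signature: (kernel_size, dilation, padding, stride)
        self.sampler = nn.Unfold(kernel_size, dilation, padding, stride)
        sample_dim = kernel_size**2 * in_channels
        self.reduction = nn.Linear(sample_dim, expansion_ratio * in_channels)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, 'input feature has wrong size'
        x = x.transpose(1, 2).reshape(B, C, H, W)
        if self.auto_pad:
            # pad bottom/right so H and W divide evenly into k x k windows
            k = self.kernel_size
            x = F.pad(x, (0, (k - W % k) % k, 0, (k - H % k) % k))
        x = self.sampler(x)       # (B, C * k * k, L_out)
        x = x.transpose(1, 2)     # (B, L_out, C * k * k)
        return self.reduction(x)  # (B, L_out, expansion_ratio * C)

With the defaults, a (1, 56 * 56, 16) input comes out as (1, 28 * 28, 32),
matching the first forward assertion in the test file below.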
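
The non-square-window fixes above concern how the relative position bias table
is indexed. A standalone illustration (the function below is illustrative, not
mmcls code): the table holds (2*Wh - 1) * (2*Ww - 1) entries, so row offsets
must be scaled by the number of distinct column offsets:

import torch


def relative_position_index(window_size):
    Wh, Ww = window_size
    coords = torch.stack(
        torch.meshgrid(torch.arange(Wh), torch.arange(Ww), indexing='ij'))
    coords = coords.flatten(1)                     # (2, Wh*Ww)
    rel = coords[:, :, None] - coords[:, None, :]  # (2, N, N) pairwise deltas
    rel = rel.permute(1, 2, 0).contiguous()        # (N, N, 2)
    rel[:, :, 0] += Wh - 1  # shift row deltas into [0, 2*Wh - 2]
    rel[:, :, 1] += Ww - 1  # shift col deltas into [0, 2*Ww - 2]
    # scale rows by the column count 2*Ww - 1; using 2*Wh - 1 here is
    # exactly the kind of bug that only appears when Wh != Ww
    rel[:, :, 0] *= 2 * Ww - 1
    return rel.sum(-1)      # (N, N), values in [0, (2*Wh-1)*(2*Ww-1))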

57 lines
2.0 KiB
Python

import pytest
import torch

from mmcls.models.utils import PatchMerging


def cal_unfold_dim(dim, kernel_size, stride, padding=0, dilation=1):
    return (dim + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def test_patch_merging():
    settings = dict(
        input_resolution=(56, 56), in_channels=16, expansion_ratio=2)
    downsample = PatchMerging(**settings)

    # test forward with wrong dims
    with pytest.raises(AssertionError):
        inputs = torch.rand((1, 16, 56 * 56))
        downsample(inputs)

    # test patch merging forward
    inputs = torch.rand((1, 56 * 56, 16))
    out = downsample(inputs)
    assert downsample.output_resolution == (28, 28)
    assert out.shape == (1, 28 * 28, 32)

    # test different kernel_size in each direction
    downsample = PatchMerging(kernel_size=(2, 3), **settings)
    out = downsample(inputs)
    expected_dim = cal_unfold_dim(56, 2, 2) * cal_unfold_dim(56, 3, 3)
    assert downsample.sampler.kernel_size == (2, 3)
    assert downsample.output_resolution == (cal_unfold_dim(56, 2, 2),
                                            cal_unfold_dim(56, 3, 3))
    assert out.shape == (1, expected_dim, 32)

    # test default stride
    downsample = PatchMerging(kernel_size=6, **settings)
    assert downsample.sampler.stride == (6, 6)

    # test stride=3
    downsample = PatchMerging(kernel_size=6, stride=3, **settings)
    out = downsample(inputs)
    assert downsample.sampler.stride == (3, 3)
    assert out.shape == (1, cal_unfold_dim(56, 6, stride=3)**2, 32)

    # test padding
    downsample = PatchMerging(kernel_size=6, padding=2, **settings)
    out = downsample(inputs)
    assert downsample.sampler.padding == (2, 2)
    assert out.shape == (1, cal_unfold_dim(56, 6, 6, padding=2)**2, 32)

    # test dilation
    downsample = PatchMerging(kernel_size=6, dilation=2, **settings)
    out = downsample(inputs)
    assert downsample.sampler.dilation == (2, 2)
    assert out.shape == (1, cal_unfold_dim(56, 6, 6, dilation=2)**2, 32)
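
cal_unfold_dim above mirrors the per-dimension output-size formula documented
for torch.nn.Unfold. A quick cross-check against the operator itself (not part
of the original test file):

import torch
import torch.nn as nn

unfold = nn.Unfold(kernel_size=6, dilation=2, padding=2, stride=3)
blocks = unfold(torch.rand(1, 16, 56, 56))  # (N, C * 6 * 6, L)
assert blocks.shape == (1, 16 * 36,
                        cal_unfold_dim(56, 6, 3, padding=2, dilation=2)**2)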