Mirror of https://github.com/open-mmlab/mmclassification.git
* Add swin transformer archs S, B and L.
* Add SwinTransformer configs
* Add train config files of swin.
* Align init method with original code
* Use nn.Unfold to merge patch
* Change all ConfigDict to dict
* Add init_cfg for all subclasses of BaseModule.
* Use mmcv version init function
* Add Swin README
* Use safer cfg copy method
* Improve docstring and variable name.
* Fix some difference in randaug. Fix BGR bug, align scheduler config. Fix label smoothing parameter difference.
* Fix missing droppath in attn
* Fix bug of relative position table if window width is not equal to height.
* Make `PatchMerging` more general, support kernel, stride, padding and dilation.
* Rename `residual` to `identity` in attention and FFN.
* Add `auto_pad` option to auto pad feature map
* Improve docstring.
* Fix bug in ShiftWMSA padding.
* Remove unused `key` and `value` in ShiftWMSA
* Move `PatchMerging` into utils and use common `PatchEmbed`.
* Use latest `LinearClsHead`, train augments and label smooth settings. And remove original `SwinLinearClsHead`.
* Mark some configs as "Evaluation Only".
* Remove useless comment in config
* 1. Move ShiftWindowMSA and WindowMSA to `utils/attention.py`
  2. Add docstrings of each module.
  3. Fix some variables' names.
  4. Other small improvement.
* Add unit tests of swin-transformer and patchmerging.
* Fix some bugs in unit tests.
* Fix bug of rel_position_index if window is not square.
* Make WindowMSA implicit, and add unit tests.
* Add metafile.yml, update readme and model_zoo.
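
The `PatchMerging` generalization called out above is the piece exercised by the test file below. A minimal sketch of the idea, assuming `nn.Unfold` semantics (`PatchMergingSketch` and its `forward(x, hw_shape)` signature are hypothetical illustrations, not the mmcls implementation; the real module, as the test shows, fixes `input_resolution` at construction):

import torch
import torch.nn as nn


class PatchMergingSketch(nn.Module):
    """Merge neighbouring patch tokens and expand channels.

    nn.Unfold gathers kernel_size x kernel_size patches with the given
    stride/padding/dilation; a linear layer then projects the stacked
    channels to in_channels * expansion_ratio.
    """

    def __init__(self, in_channels, expansion_ratio=2, kernel_size=2,
                 stride=None, padding=0, dilation=1):
        super().__init__()
        # stride defaults to kernel_size, i.e. non-overlapping windows
        stride = stride or kernel_size
        self.sampler = nn.Unfold(kernel_size, dilation, padding, stride)
        sample_dim = kernel_size * kernel_size * in_channels
        self.norm = nn.LayerNorm(sample_dim)
        self.reduction = nn.Linear(
            sample_dim, in_channels * expansion_ratio, bias=False)

    def forward(self, x, hw_shape):
        # x: (B, H*W, C) patch tokens laid out on an (H, W) grid.
        B, L, C = x.shape
        H, W = hw_shape
        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = self.sampler(x)        # (B, kernel_size**2 * C, L_out)
        x = x.transpose(1, 2)      # (B, L_out, kernel_size**2 * C)
        return self.reduction(self.norm(x))

Under these assumptions, `PatchMergingSketch(16)(torch.rand(1, 56 * 56, 16), (56, 56))` yields a `(1, 28 * 28, 32)` tensor, matching the shapes asserted in the test below.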
57 lines · 2.0 KiB · Python
import pytest
import torch

from mmcls.models.utils import PatchMerging


def cal_unfold_dim(dim, kernel_size, stride, padding=0, dilation=1):
    """Compute the output length of ``nn.Unfold`` along one axis."""
    return (dim + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
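
# Quick sanity check of the helper above (a sketch, not in the original
# file): the formula mirrors the number of sliding positions produced by
# torch.nn.Unfold, e.g. kernel 2 at stride 2 over a 56-long axis gives
# (56 + 0 - 1 * (2 - 1) - 1) // 2 + 1 = 28, so
#   torch.nn.Unfold(2, stride=2)(torch.rand(1, 1, 56, 56)).shape[-1]
# equals cal_unfold_dim(56, 2, 2) ** 2 == 784.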


def test_patch_merging():
    settings = dict(
        input_resolution=(56, 56), in_channels=16, expansion_ratio=2)
    downsample = PatchMerging(**settings)
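
    # With in_channels=16 and expansion_ratio=2, each merged patch is
    # projected to 16 * 2 = 32 output channels; the shape asserts below
    # rely on this.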

    # test forward with wrong dims
    with pytest.raises(AssertionError):
        inputs = torch.rand((1, 16, 56 * 56))
        downsample(inputs)
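
    # PatchMerging expects token-major input of shape (B, H*W, C); the
    # channel-major tensor above trips its shape assertion.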

    # test patch merging forward
    inputs = torch.rand((1, 56 * 56, 16))
    out = downsample(inputs)
    assert downsample.output_resolution == (28, 28)
    assert out.shape == (1, 28 * 28, 32)
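
    # The default merge gathers 2x2 windows at stride 2, halving the
    # 56x56 grid to 28x28.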

    # test different kernel_size in each direction
    downsample = PatchMerging(kernel_size=(2, 3), **settings)
    out = downsample(inputs)
    expected_dim = cal_unfold_dim(56, 2, 2) * cal_unfold_dim(56, 3, 3)
    assert downsample.sampler.kernel_size == (2, 3)
    assert downsample.output_resolution == (cal_unfold_dim(56, 2, 2),
                                            cal_unfold_dim(56, 3, 3))
    assert out.shape == (1, expected_dim, 32)
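
    # A rectangular kernel merges each axis independently: 56 -> 28
    # along the kernel-2 axis and 56 -> 18 along the kernel-3 axis, so
    # 28 * 18 tokens remain.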

    # test default stride
    downsample = PatchMerging(kernel_size=6, **settings)
    assert downsample.sampler.stride == (6, 6)
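
    # When stride is left unset it defaults to kernel_size, i.e. the
    # merge windows do not overlap.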

    # test stride=3
    downsample = PatchMerging(kernel_size=6, stride=3, **settings)
    out = downsample(inputs)
    assert downsample.sampler.stride == (3, 3)
    assert out.shape == (1, cal_unfold_dim(56, 6, stride=3)**2, 32)
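
    # Overlapping 6x6 windows at stride 3: (56 - 5 - 1) // 3 + 1 = 17
    # positions per axis, so 17 * 17 output tokens.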

    # test padding
    downsample = PatchMerging(kernel_size=6, padding=2, **settings)
    out = downsample(inputs)
    assert downsample.sampler.padding == (2, 2)
    assert out.shape == (1, cal_unfold_dim(56, 6, 6, padding=2)**2, 32)
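
    # Padding of 2 per side: (56 + 4 - 5 - 1) // 6 + 1 = 10 positions
    # per axis.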

    # test dilation
    downsample = PatchMerging(kernel_size=6, dilation=2, **settings)
    out = downsample(inputs)
    assert downsample.sampler.dilation == (2, 2)
    assert out.shape == (1, cal_unfold_dim(56, 6, 6, dilation=2)**2, 32)
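
    # Dilation of 2 widens the effective kernel to 2 * (6 - 1) + 1 = 11,
    # so (56 - 10 - 1) // 6 + 1 = 8 positions per axis.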