[Feature] Add swin-transformer model. (#271)

* Add Swin Transformer archs S, B and L.

* Add SwinTransformer configs

* Add train config files of swin.

* Align init method with original code

* Use nn.Unfold to merge patch

* Change all ConfigDict to dict

* Add init_cfg for all subclasses of BaseModule.

* Use mmcv version init function

* Add Swin README

* Use safer cfg copy method

* Improve docstring and variable name.

* Fix some differences in RandAugment

Fix BGR bug, align scheduler config.

Fix label smoothing parameter difference.

* Fix missing droppath in attn

* Fix bug of relative position table if window width is not equal to height.

* Make `PatchMerging` more general, support kernel, stride, padding and
dilation.

* Rename `residual` to `identity` in attention and FFN.

* Add `auto_pad` option to auto pad feature map

* Improve docstring.

* Fix bug in ShiftWMSA padding.

* Remove unused `key` and `value` in ShiftWMSA

* Move `PatchMerging` into utils and use common `PatchEmbed`.

* Use latest `LinearClsHead`, train augments and label smooth settings.
And remove original `SwinLinearClsHead`.

* Mark some configs as "Evaluation Only".

* Remove useless comment in config

* 1. Move ShiftWindowMSA and WindowMSA to `utils/attention.py`
2. Add docstrings of each module.
3. Fix some variables' names.
4. Other small improvement.

* Add unit tests of swin-transformer and patchmerging.

* Fix some bugs in unit tests.

* Fix bug of rel_position_index if window is not square.

* Make WindowMSA implicit, and add unit tests.

* Add metafile.yml, update readme and model_zoo.
Ma Zerun 2021-07-01 09:30:42 +08:00 committed by GitHub
parent 4ebee155e8
commit 076ee10cac
28 changed files with 1569 additions and 4 deletions


@ -49,6 +49,7 @@ Supported backbones:
- [x] ShuffleNetV2
- [x] MobileNetV2
- [x] MobileNetV3
- [x] Swin-Transformer
## Installation


@ -49,6 +49,7 @@ MMClassification 是一款基于 PyTorch 的开源图像分类工具箱,是 [O
- [x] ShuffleNetV2
- [x] MobileNetV2
- [x] MobileNetV3
- [x] Swin-Transformer
## 安装


@ -0,0 +1,122 @@
# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
policies = [
dict(type='AutoContrast'),
dict(type='Equalize'),
dict(type='Invert'),
dict(
type='Rotate',
interpolation='bicubic',
magnitude_key='angle',
pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
magnitude_range=(0, 30)),
dict(type='Posterize', magnitude_key='bits', magnitude_range=(4, 0)),
dict(type='Solarize', magnitude_key='thr', magnitude_range=(256, 0)),
dict(
type='SolarizeAdd',
magnitude_key='magnitude',
magnitude_range=(0, 110)),
dict(
type='ColorTransform',
magnitude_key='magnitude',
magnitude_range=(0, 0.9)),
dict(type='Contrast', magnitude_key='magnitude', magnitude_range=(0, 0.9)),
dict(
type='Brightness', magnitude_key='magnitude',
magnitude_range=(0, 0.9)),
dict(
type='Sharpness', magnitude_key='magnitude', magnitude_range=(0, 0.9)),
dict(
type='Shear',
interpolation='bicubic',
magnitude_key='magnitude',
magnitude_range=(0, 0.3),
pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
direction='horizontal'),
dict(
type='Shear',
interpolation='bicubic',
magnitude_key='magnitude',
magnitude_range=(0, 0.3),
pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
direction='vertical'),
dict(
type='Translate',
interpolation='bicubic',
magnitude_key='magnitude',
magnitude_range=(0, 0.45),
pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
direction='horizontal'),
dict(
type='Translate',
interpolation='bicubic',
magnitude_key='magnitude',
magnitude_range=(0, 0.45),
pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
direction='vertical')
]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
size=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies=policies,
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=img_norm_cfg['mean'][::-1],
fill_std=img_norm_cfg['std'][::-1]),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
size=(256, -1),
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
data = dict(
samples_per_gpu=128,
workers_per_gpu=8,
train=dict(
type=dataset_type,
data_prefix='data/imagenet/train',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_prefix='data/imagenet/val',
ann_file='data/imagenet/meta/val.txt',
pipeline=test_pipeline),
test=dict(
# replace `data/val` with `data/test` for standard test
type=dataset_type,
data_prefix='data/imagenet/val',
ann_file='data/imagenet/meta/val.txt',
pipeline=test_pipeline))
evaluation = dict(interval=10, metric='accuracy')
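
The `pad_val` used by the geometric augmentations above (Rotate, Shear, Translate) is just the normalization mean, reversed into BGR order (images are still BGR at that point in the pipeline) and rounded to integers. A quick sketch of that arithmetic:

```python
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# pad_val = dataset mean in BGR order, rounded to integer pixel values
pad_val = tuple(round(x) for x in img_norm_cfg['mean'][::-1])
print(pad_val)  # (104, 116, 124)
```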


@ -0,0 +1,43 @@
# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
size=384,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', size=384, backend='pillow', interpolation='bicubic'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
data = dict(
samples_per_gpu=128,
workers_per_gpu=8,
train=dict(
type=dataset_type,
data_prefix='data/imagenet/train',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_prefix='data/imagenet/val',
ann_file='data/imagenet/meta/val.txt',
pipeline=test_pipeline),
test=dict(
# replace `data/val` with `data/test` for standard test
type=dataset_type,
data_prefix='data/imagenet/val',
ann_file='data/imagenet/meta/val.txt',
pipeline=test_pipeline))
evaluation = dict(interval=10, metric='accuracy')


@ -0,0 +1,22 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='SwinTransformer', arch='base', img_size=224, drop_path_rate=0.5),
neck=dict(type='GlobalAveragePooling', dim=1),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1024,
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
cal_acc=False),
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))
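
As a rough sketch of what `mode='original'` label smoothing means for the classification target (my reading of the setting, not the `LabelSmoothLoss` code itself): the one-hot target is mixed with a uniform distribution over the 1000 classes.

```python
import torch

num_classes, eps = 1000, 0.1
target = torch.zeros(num_classes)
target[3] = 1.0  # ground-truth class 3

# 'original' label smoothing: (1 - eps) * one_hot + eps / num_classes
smoothed = target * (1 - eps) + eps / num_classes
print(smoothed[3].item(), smoothed[0].item())  # ~0.9001  ~0.0001
```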


@ -0,0 +1,16 @@
# model settings
# Only for evaluation
model = dict(
type='ImageClassifier',
backbone=dict(
type='SwinTransformer',
arch='base',
img_size=384,
stage_cfgs=dict(block_cfgs=dict(window_size=12))),
neck=dict(type='GlobalAveragePooling', dim=1),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1024,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5)))


@ -0,0 +1,12 @@
# model settings
# Only for evaluation
model = dict(
type='ImageClassifier',
backbone=dict(type='SwinTransformer', arch='large', img_size=224),
neck=dict(type='GlobalAveragePooling', dim=1),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1536,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5)))


@ -0,0 +1,16 @@
# model settings
# Only for evaluation
model = dict(
type='ImageClassifier',
backbone=dict(
type='SwinTransformer',
arch='large',
img_size=384,
stage_cfgs=dict(block_cfgs=dict(window_size=12))),
neck=dict(type='GlobalAveragePooling', dim=1),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1536,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5)))


@ -0,0 +1,23 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='SwinTransformer', arch='small', img_size=224,
drop_path_rate=0.3),
neck=dict(type='GlobalAveragePooling', dim=1),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=768,
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
cal_acc=False),
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))


@ -0,0 +1,22 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='SwinTransformer', arch='tiny', img_size=224, drop_path_rate=0.2),
neck=dict(type='GlobalAveragePooling', dim=1),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=768,
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
cal_acc=False),
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))


@ -0,0 +1,30 @@
paramwise_cfg = dict(
norm_decay_mult=0.0,
bias_decay_mult=0.0,
custom_keys={
'.absolute_pos_embed': dict(decay_mult=0.0),
'.relative_position_bias_table': dict(decay_mult=0.0)
})
# With 8 GPUs and 128 samples per GPU (total batch size 1024):
# lr = 5e-4 * 128 * 8 / 512 = 0.001
optimizer = dict(
type='AdamW',
lr=5e-4 * 128 * 8 / 512,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999),
paramwise_cfg=paramwise_cfg)
optimizer_config = dict(grad_clip=dict(max_norm=5.0))
# learning policy
lr_config = dict(
policy='CosineAnnealing',
by_epoch=False,
min_lr_ratio=1e-2,
warmup='linear',
warmup_ratio=1e-3,
warmup_iters=20 * 1252,
warmup_by_epoch=False)
runner = dict(type='EpochBasedRunner', max_epochs=300)
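
The comment above is the linear LR scaling rule: the base rate 5e-4 is defined for a total batch size of 512 and scaled to the actual batch size (8 GPUs x 128 samples). The 20-epoch warm-up is written in iterations; a sketch of both calculations, assuming the ~1.28M-image ImageNet-1k training set:

```python
from math import ceil

samples_per_gpu, num_gpus = 128, 8
base_lr, base_batch_size = 5e-4, 512

lr = base_lr * samples_per_gpu * num_gpus / base_batch_size
print(lr)  # 0.001

# 20 warm-up epochs expressed in iterations
iters_per_epoch = ceil(1281167 / (samples_per_gpu * num_gpus))
print(iters_per_epoch, 20 * iters_per_epoch)  # 1252 25040
```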


@ -0,0 +1,41 @@
# Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
## Introduction
[ALGORITHM]
```latex
@article{liu2021Swin,
title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},
journal={arXiv preprint arXiv:2103.14030},
year={2021}
}
```
## Pretrained models
The pre-trained models are converted from the [model zoo of Swin Transformer](https://github.com/microsoft/Swin-Transformer#main-results-on-imagenet-with-pretrained-models).
### ImageNet 1k
| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Download |
|:---------:|:------------:|:-----------:|:---------:|:---------:|:---------:|:---------:|:--------:|
| Swin-T | ImageNet-1k | 224x224 | 28.29 | 4.36 | 81.18 | 95.52 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_tiny_patch4_window7_224-160bb0a5.pth)|
| Swin-S | ImageNet-1k | 224x224 | 49.61 | 8.52 | 83.21 | 96.25 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_small_patch4_window7_224-cc7a01c9.pth)|
| Swin-B | ImageNet-1k | 224x224 | 87.77 | 15.14 | 83.42 | 96.44 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window7_224-4670dd19.pth)|
| Swin-B | ImageNet-1k | 384x384 | 87.90 | 44.49 | 84.49 | 96.95 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window12_384-02c598a4.pth)|
| Swin-B | ImageNet-22k | 224x224 | 87.77 | 15.14 | 85.16 | 97.50 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window7_224_22kto1k-f967f799.pth)|
| Swin-B | ImageNet-22k | 384x384 | 87.90 | 44.49 | 86.44 | 98.05 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window12_384_22kto1k-d59b0d1d.pth)|
| Swin-L | ImageNet-22k | 224x224 | 196.53 | 34.04 | 86.24 | 97.88 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_large_patch4_window7_224_22kto1k-5f0996db.pth)|
| Swin-L | ImageNet-22k | 384x384 | 196.74 | 100.04 | 87.25 | 98.25 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_large_patch4_window12_384_22kto1k-0a40944b.pth)|
## Results and models
### ImageNet
| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------:|:------------:|:-----------:|:---------:|:---------:|:---------:|:---------:|:----------:|:--------:|
| Swin-T | ImageNet-1k | 224x224 | 28.29 | 4.36 | 81.18 | 95.61 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_tiny_224_imagenet.py) |[model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_imagenet-66df6be6.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_imagenet-66df6be6.log.json)|
| Swin-S | ImageNet-1k | 224x224 | 49.61 | 8.52 | 83.02 | 96.29 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_small_224_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_imagenet-7f9d988b.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_imagenet-7f9d988b.log.json)|
| Swin-B | ImageNet-1k | 224x224 | 87.77 | 15.14 | 83.36 | 96.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_base_224_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_imagenet-93230b0d.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_imagenet-93230b0d.log.json)|
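
A minimal inference sketch with one of the checkpoints above, assuming the standard `mmcls.apis` helpers (`init_model` / `inference_model`) and any local test image:

```python
from mmcls.apis import inference_model, init_model

config = 'configs/swin_transformer/swin_tiny_224_imagenet.py'
checkpoint = ('https://download.openmmlab.com/mmclassification/v0/'
              'swin-transformer/swin_tiny_224_imagenet-66df6be6.pth')

model = init_model(config, checkpoint, device='cpu')
result = inference_model(model, 'path/to/some_image.jpg')  # any RGB image file
print(result)  # dict with the predicted label, score and class name
```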


@ -0,0 +1,67 @@
Collections:
- Name: Swin-Transformer
Metadata:
Training Data: ImageNet
Training Techniques:
- AdamW
- Weight Decay
Training Resources: 16x V100 GPUs
Epochs: 300
Batch Size: 1024
Architecture:
- Shift Window Multihead Self Attention
Paper: https://arxiv.org/pdf/2103.14030.pdf
README: configs/swin_transformer/README.md
Models:
- Config: configs/swin_transformer/swin_tiny_224_imagenet.py
In Collection: Swin-Transformer
Metadata:
FLOPs: 4360000000
Parameters: 28290000
Training Data: ImageNet
Training Resources: 16x 1080 GPUs
Epochs: 300
Batch Size: 1024
Name: swin_tiny_224_imagenet
Results:
- Dataset: ImageNet
Metrics:
Top 1 Accuracy: 81.18
Top 5 Accuracy: 95.61
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_imagenet-66df6be6.pth
- Config: configs/swin_transformer/swin_small_224_imagenet.py
In Collection: Swin-Transformer
Metadata:
FLOPs: 8520000000
Parameters: 48610000
Training Data: ImageNet
Training Resources: 16x 1080 GPUs
Epochs: 300
Batch Size: 1024
Name: swin_small_224_imagenet
Results:
- Dataset: ImageNet
Metrics:
Top 1 Accuracy: 83.02
Top 5 Accuracy: 96.29
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_imagenet-7f9d988b.pth
- Config: configs/swin_transformer/swin_base_224_imagenet.py
In Collection: Swin-Transformer
Metadata:
FLOPs: 15140000000
Parameters: 87770000
Training Data: ImageNet
Training Resources: 16x 1080 GPUs
Epochs: 300
Batch Size: 1024
Name: swin_base_224_imagenet
Results:
- Dataset: ImageNet
Metrics:
Top 1 Accuracy: 83.36
Top 5 Accuracy: 96.44
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_imagenet-93230b0d.pth


@ -0,0 +1,6 @@
_base_ = [
'../_base_/models/swin_transformer/base_224.py',
'../_base_/datasets/imagenet_bs128_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py'
]


@ -0,0 +1,7 @@
# Only for evaluation
_base_ = [
'../_base_/models/swin_transformer/base_384.py',
'../_base_/datasets/imagenet_bs128_swin_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py'
]


@ -0,0 +1,7 @@
# Only for evaluation
_base_ = [
'../_base_/models/swin_transformer/large_224.py',
'../_base_/datasets/imagenet_bs128_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py'
]


@ -0,0 +1,7 @@
# Only for evaluation
_base_ = [
'../_base_/models/swin_transformer/large_384.py',
'../_base_/datasets/imagenet_bs128_swin_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py'
]


@ -0,0 +1,6 @@
_base_ = [
'../_base_/models/swin_transformer/small_224.py',
'../_base_/datasets/imagenet_bs128_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py'
]


@ -0,0 +1,6 @@
_base_ = [
'../_base_/models/swin_transformer/tiny_224.py',
'../_base_/datasets/imagenet_bs128_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py'
]


@ -40,6 +40,9 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
| ViT-B/32* | 88.3 | 8.56 | 81.73 | 96.13 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/vision_transformer/vit_base_patch32_384_finetune_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/vit/vit_base_patch32_384.pth) | [log]() |
| ViT-L/16* | 304.72 | 116.68 | 85.08 | 97.38 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/vision_transformer/vit_large_patch16_384_finetune_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/vit/vit_large_patch16_384.pth) | [log]() |
| ViT-L/32* | 306.63 | 29.66 | 81.52 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/vision_transformer/vit_large_patch32_384_finetune_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/vit/vit_large_patch32_384.pth) | [log]() |
| Swin-Transformer tiny | 28.29 | 4.36 | 81.18 | 95.61 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_tiny_224_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_imagenet-66df6be6.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_imagenet-66df6be6.log.json)|
| Swin-Transformer small| 49.61 | 8.52 | 83.02 | 96.29 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_small_224_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_imagenet-7f9d988b.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_imagenet-7f9d988b.log.json)|
| Swin-Transformer base | 87.77 | 15.14 | 83.36 | 96.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_base_224_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_imagenet-93230b0d.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_imagenet-93230b0d.log.json)|
Models with * are converted from other repos, while the others are trained by ourselves.


@ -11,11 +11,13 @@ from .seresnet import SEResNet
from .seresnext import SEResNeXt
from .shufflenet_v1 import ShuffleNetV1
from .shufflenet_v2 import ShuffleNetV2
from .swin_transformer import SwinTransformer
from .vgg import VGG
from .vision_transformer import VisionTransformer
__all__ = [
'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer'
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
'SwinTransformer'
]


@ -0,0 +1,349 @@
from copy import deepcopy
from typing import Sequence
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from ..builder import BACKBONES
from ..utils import PatchEmbed, PatchMerging, ShiftWindowMSA
from .base_backbone import BaseBackbone
class SwinBlock(BaseModule):
"""Swin Transformer block.
Args:
embed_dims (int): Number of input channels.
input_resolution (Tuple[int, int]): The resolution of the input feature
map.
num_heads (int): Number of attention heads.
window_size (int, optional): The height and width of the window.
Defaults to 7.
shift (bool, optional): Shift the attention window or not.
Defaults to False.
ffn_ratio (float, optional): The expansion ratio of feedforward network
hidden layer channels. Defaults to 4.
drop_path (float, optional): The drop path rate after attention and
ffn. Defaults to 0.
attn_cfgs (dict, optional): The extra config of Shift Window-MSA.
Defaults to empty dict.
ffn_cfgs (dict, optional): The extra config of FFN.
Defaults to empty dict.
norm_cfg (dict, optional): The config of norm layers.
Defaults to dict(type='LN').
auto_pad (bool, optional): Auto pad the feature map to be divisible by
window_size. Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
input_resolution,
num_heads,
window_size=7,
shift=False,
ffn_ratio=4.,
drop_path=0.,
attn_cfgs=dict(),
ffn_cfgs=dict(),
norm_cfg=dict(type='LN'),
auto_pad=False,
init_cfg=None):
super(SwinBlock, self).__init__(init_cfg)
_attn_cfgs = {
'embed_dims': embed_dims,
'input_resolution': input_resolution,
'num_heads': num_heads,
'shift_size': window_size // 2 if shift else 0,
'window_size': window_size,
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
'auto_pad': auto_pad,
**attn_cfgs
}
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
self.attn = ShiftWindowMSA(**_attn_cfgs)
_ffn_cfgs = {
'embed_dims': embed_dims,
'feedforward_channels': int(embed_dims * ffn_ratio),
'num_fcs': 2,
'ffn_drop': 0,
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
'act_cfg': dict(type='GELU'),
**ffn_cfgs
}
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
self.ffn = FFN(**_ffn_cfgs)
def forward(self, x):
identity = x
x = self.norm1(x)
x = self.attn(x)
x = x + identity
identity = x
x = self.norm2(x)
x = self.ffn(x, identity=identity)
return x
class SwinBlockSequence(BaseModule):
"""Module with successive Swin Transformer blocks and downsample layer.
Args:
embed_dims (int): Number of input channels.
input_resolution (Tuple[int, int]): The resolution of the input feature
map.
depth (int): Number of successive swin transformer blocks.
num_heads (int): Number of attention heads.
downsample (bool, optional): Downsample the output of blocks by patch
merging. Defaults to False.
downsample_cfg (dict, optional): The extra config of the patch merging
layer. Defaults to empty dict.
drop_paths (Sequence[float] | float, optional): The drop path rate in
each block. Defaults to 0.
block_cfgs (Sequence[dict] | dict, optional): The extra config of each
block. Defaults to empty dicts.
auto_pad (bool, optional): Auto pad the feature map to be divisible by
window_size. Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
input_resolution,
depth,
num_heads,
downsample=False,
downsample_cfg=dict(),
drop_paths=0.,
block_cfgs=dict(),
auto_pad=False,
init_cfg=None):
super().__init__(init_cfg)
if not isinstance(drop_paths, Sequence):
drop_paths = [drop_paths] * depth
if not isinstance(block_cfgs, Sequence):
block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]
self.blocks = ModuleList()
for i in range(depth):
_block_cfg = {
'embed_dims': embed_dims,
'input_resolution': input_resolution,
'num_heads': num_heads,
'shift': False if i % 2 == 0 else True,
'drop_path': drop_paths[i],
'auto_pad': auto_pad,
**block_cfgs[i]
}
block = SwinBlock(**_block_cfg)
self.blocks.append(block)
if downsample:
_downsample_cfg = {
'input_resolution': input_resolution,
'in_channels': embed_dims,
'expansion_ratio': 2,
'norm_cfg': dict(type='LN'),
**downsample_cfg
}
self.downsample = PatchMerging(**_downsample_cfg)
else:
self.downsample = None
def forward(self, x):
for block in self.blocks:
x = block(x)
if self.downsample:
x = self.downsample(x)
return x
@BACKBONES.register_module()
class SwinTransformer(BaseBackbone):
""" Swin Transformer
A PyTorch implement of : `Swin Transformer:
Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/abs/2103.14030
Inspiration from
https://github.com/microsoft/Swin-Transformer
Args:
arch (str | dict): Swin Transformer architecture.
Defaults to 'T'.
img_size (int | tuple): The size of input image.
Defaults to 224.
in_channels (int): The num of input channels.
Defaults to 3.
drop_rate (float): Dropout rate after embedding.
Defaults to 0.
drop_path_rate (float): Stochastic depth rate.
Defaults to 0.1.
use_abs_pos_embed (bool): If True, add absolute position embedding to
the patch embedding. Defaults to False.
auto_pad (bool): If True, auto pad feature map to fit window_size.
Defaults to False.
norm_cfg (dict, optional): Config dict for normalization layer at end
of backbone. Defaults to dict(type='LN').
stage_cfgs (Sequence | dict, optional): Extra config dict for each
stage. Defaults to empty dict.
patch_cfg (dict, optional): Extra config dict for patch embedding.
Defaults to empty dict.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
Examples:
>>> from mmcls.models import SwinTransformer
>>> import torch
>>> extra_config = dict(
>>> arch='tiny',
>>> stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
>>> 'expansion_ratio': 3}),
>>> auto_pad=True)
>>> self = SwinTransformer(**extra_config)
>>> inputs = torch.rand(1, 3, 224, 224)
>>> output = self.forward(inputs)
>>> print(output.shape)
(1, 2592, 4)
"""
arch_zoo = {
**dict.fromkeys(['t', 'tiny'],
{'embed_dims': 96,
'depths': [2, 2, 6, 2],
'num_heads': [3, 6, 12, 24]}),
**dict.fromkeys(['s', 'small'],
{'embed_dims': 96,
'depths': [2, 2, 18, 2],
'num_heads': [3, 6, 12, 24]}),
**dict.fromkeys(['b', 'base'],
{'embed_dims': 128,
'depths': [2, 2, 18, 2],
'num_heads': [4, 8, 16, 32]}),
**dict.fromkeys(['l', 'large'],
{'embed_dims': 192,
'depths': [2, 2, 18, 2],
'num_heads': [6, 12, 24, 48]}),
} # yapf: disable
def __init__(self,
arch='T',
img_size=224,
in_channels=3,
drop_rate=0.,
drop_path_rate=0.1,
use_abs_pos_embed=False,
auto_pad=False,
norm_cfg=dict(type='LN'),
stage_cfgs=dict(),
patch_cfg=dict(),
init_cfg=None):
super(SwinTransformer, self).__init__(init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {'embed_dims', 'depths', 'num_heads'}
assert isinstance(arch, dict) and set(arch) == essential_keys, \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.embed_dims = self.arch_settings['embed_dims']
self.depths = self.arch_settings['depths']
self.num_heads = self.arch_settings['num_heads']
self.num_layers = len(self.depths)
self.use_abs_pos_embed = use_abs_pos_embed
self.auto_pad = auto_pad
_patch_cfg = dict(
img_size=img_size,
in_channels=in_channels,
embed_dims=self.embed_dims,
conv_cfg=dict(
type='Conv2d', kernel_size=4, stride=4, padding=0, dilation=1),
norm_cfg=dict(type='LN'),
**patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
if self.use_abs_pos_embed:
self.absolute_pos_embed = nn.Parameter(
torch.zeros(1, num_patches, self.embed_dims))
self.drop_after_pos = nn.Dropout(p=drop_rate)
# stochastic depth
total_depth = sum(self.depths)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
] # stochastic depth decay rule
self.stages = ModuleList()
embed_dims = self.embed_dims
input_resolution = patches_resolution
for i, (depth,
num_heads) in enumerate(zip(self.depths, self.num_heads)):
if isinstance(stage_cfgs, Sequence):
stage_cfg = stage_cfgs[i]
else:
stage_cfg = deepcopy(stage_cfgs)
downsample = True if i < self.num_layers - 1 else False
_stage_cfg = {
'embed_dims': embed_dims,
'depth': depth,
'num_heads': num_heads,
'downsample': downsample,
'input_resolution': input_resolution,
'drop_paths': dpr[:depth],
'auto_pad': auto_pad,
**stage_cfg
}
stage = SwinBlockSequence(**_stage_cfg)
self.stages.append(stage)
dpr = dpr[depth:]
if downsample:
embed_dims = stage.downsample.out_channels
input_resolution = stage.downsample.output_resolution
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
else:
self.norm = None
def init_weights(self):
super(SwinTransformer, self).init_weights()
if self.use_abs_pos_embed:
trunc_normal_(self.absolute_pos_embed, std=0.02)
def forward(self, x):
x = self.patch_embed(x)
if self.use_abs_pos_embed:
x = x + self.absolute_pos_embed
x = self.drop_after_pos(x)
for stage in self.stages:
x = stage(x)
x = self.norm(x) if self.norm else x
return x.transpose(1, 2)
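
The stochastic depth schedule built above (`torch.linspace` over the total depth, consumed stage by stage via `dpr[:depth]` and `dpr[depth:]`) gives each block a drop-path probability that grows linearly from 0 to `drop_path_rate`. A small sketch:

```python
import torch

depths = [2, 2, 18, 2]  # e.g. the 'small' arch
drop_path_rate = 0.2
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

# block k (counted across all stages) gets 0.2 * k / (sum(depths) - 1)
print(dpr[0], dpr[-1])   # 0.0  ~0.2
print(round(dpr[1], 6))  # 0.008696
```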


@ -1,6 +1,7 @@
from .attention import ShiftWindowMSA
from .augment.augments import Augments
from .channel_shuffle import channel_shuffle
from .embed import HybridEmbed, PatchEmbed
from .embed import HybridEmbed, PatchEmbed, PatchMerging
from .helpers import to_2tuple, to_3tuple, to_4tuple, to_ntuple
from .inverted_residual import InvertedResidual
from .make_divisible import make_divisible
@ -8,6 +9,6 @@ from .se_layer import SELayer
__all__ = [
'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer',
'to_ntuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'Augments',
'HybridEmbed', 'PatchEmbed'
'to_ntuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'PatchEmbed',
'PatchMerging', 'HybridEmbed', 'Augments', 'ShiftWindowMSA'
]


@ -0,0 +1,289 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.cnn.bricks.transformer import build_dropout
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule
from .helpers import to_2tuple
class WindowMSA(BaseModule):
"""Window based multi-head self-attention (W-MSA) module with relative
position bias.
Args:
embed_dims (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Defaults to True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Defaults to None.
attn_drop (float, optional): Dropout ratio of attention weight.
Defaults to 0.
proj_drop (float, optional): Dropout ratio of output. Defaults to 0.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
init_cfg=None):
super().__init__(init_cfg)
self.embed_dims = embed_dims
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_embed_dims = embed_dims // num_heads
self.scale = qk_scale or head_embed_dims**-0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# About 2x faster than original impl
Wh, Ww = self.window_size
rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
rel_position_index = rel_index_coords + rel_index_coords.T
rel_position_index = rel_position_index.flip(1).contiguous()
self.register_buffer('relative_position_index', rel_position_index)
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(embed_dims, embed_dims)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
def init_weights(self):
super(WindowMSA, self).init_weights()
trunc_normal_(self.relative_position_bias_table, std=0.02)
def forward(self, x, mask=None):
"""
Args:
x (tensor): input features with shape of (num_windows*B, N, C)
mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww,
Wh*Ww), value should be between (-inf, 0].
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[
2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N,
N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@staticmethod
def double_step_seq(step1, len1, step2, len2):
seq1 = torch.arange(0, step1 * len1, step1)
seq2 = torch.arange(0, step2 * len2, step2)
return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
@ATTENTION.register_module()
class ShiftWindowMSA(BaseModule):
"""Shift Window Multihead Self-Attention Module.
Args:
embed_dims (int): Number of input channels.
input_resolution (Tuple[int, int]): The resolution of the input feature
map.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window.
shift_size (int, optional): The shift step of each window towards
right-bottom. If zero, act as regular window-msa. Defaults to 0.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: True
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Defaults to None.
attn_drop (float, optional): Dropout ratio of attention weight.
Defaults to 0.0.
proj_drop (float, optional): Dropout ratio of output. Defaults to 0.
dropout_layer (dict, optional): The dropout_layer used before output.
Defaults to dict(type='DropPath', drop_prob=0.).
auto_pad (bool, optional): Auto pad the feature map to be divisible by
window_size. Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
input_resolution,
num_heads,
window_size,
shift_size=0,
qkv_bias=True,
qk_scale=None,
attn_drop=0,
proj_drop=0,
dropout_layer=dict(type='DropPath', drop_prob=0.),
auto_pad=False,
init_cfg=None):
super().__init__(init_cfg)
self.embed_dims = embed_dims
self.input_resolution = input_resolution
self.shift_size = shift_size
self.window_size = window_size
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, don't partition
self.shift_size = 0
self.window_size = min(self.input_resolution)
self.w_msa = WindowMSA(embed_dims, to_2tuple(self.window_size),
num_heads, qkv_bias, qk_scale, attn_drop,
proj_drop)
self.drop = build_dropout(dropout_layer)
H, W = self.input_resolution
# Handle auto padding
self.auto_pad = auto_pad
if self.auto_pad:
self.pad_r = (self.window_size -
W % self.window_size) % self.window_size
self.pad_b = (self.window_size -
H % self.window_size) % self.window_size
self.H_pad = H + self.pad_b
self.W_pad = W + self.pad_r
else:
H_pad, W_pad = self.input_resolution
assert H_pad % self.window_size + W_pad % self.window_size == 0,\
f'input_resolution({self.input_resolution}) is not divisible '\
f'by window_size({self.window_size}). Please check feature '\
f'map shape or set `auto_pad=True`.'
self.H_pad, self.W_pad = H_pad, W_pad
self.pad_r, self.pad_b = 0, 0
if self.shift_size > 0:
# calculate attention mask for SW-MSA
img_mask = torch.zeros((1, self.H_pad, self.W_pad, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
# nW, window_size, window_size, 1
mask_windows = self.window_partition(img_mask)
mask_windows = mask_windows.view(
-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0,
float(-100.0)).masked_fill(
attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer('attn_mask', attn_mask)
def forward(self, query):
H, W = self.input_resolution
B, L, C = query.shape
assert L == H * W, 'input feature has wrong size'
query = query.view(B, H, W, C)
if self.pad_r or self.pad_b:
query = F.pad(query, (0, 0, 0, self.pad_r, 0, self.pad_b))
# cyclic shift
if self.shift_size > 0:
shifted_query = torch.roll(
query,
shifts=(-self.shift_size, -self.shift_size),
dims=(1, 2))
else:
shifted_query = query
# nW*B, window_size, window_size, C
query_windows = self.window_partition(shifted_query)
# nW*B, window_size*window_size, C
query_windows = query_windows.view(-1, self.window_size**2, C)
# W-MSA/SW-MSA (nW*B, window_size*window_size, C)
attn_windows = self.w_msa(query_windows, mask=self.attn_mask)
# merge windows
attn_windows = attn_windows.view(-1, self.window_size,
self.window_size, C)
# B H' W' C
shifted_x = self.window_reverse(attn_windows, self.H_pad, self.W_pad)
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(
shifted_x,
shifts=(self.shift_size, self.shift_size),
dims=(1, 2))
else:
x = shifted_x
if self.pad_r or self.pad_b:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = self.drop(x)
return x
def window_reverse(self, windows, H, W):
window_size = self.window_size
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size,
window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
def window_partition(self, x):
B, H, W, C = x.shape
window_size = self.window_size
x = x.view(B, H // window_size, window_size, W // window_size,
window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
windows = windows.view(-1, window_size, window_size, C)
return windows
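
For reference, a self-contained sketch of the window partition / reverse round trip used above, written as free functions with the same reshuffling:

```python
import torch


def window_partition(x, ws):
    """(B, H, W, C) -> (num_windows * B, ws, ws, C)"""
    B, H, W, C = x.shape
    x = x.view(B, H // ws, ws, W // ws, ws, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, ws, ws, C)


def window_reverse(windows, ws, H, W):
    """Inverse of window_partition."""
    B = windows.shape[0] // ((H // ws) * (W // ws))
    x = windows.view(B, H // ws, W // ws, ws, ws, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)


x = torch.rand(2, 14, 14, 96)
windows = window_partition(x, 7)  # shape (8, 7, 7, 96)
assert torch.equal(window_reverse(windows, 7, 14, 14), x)
```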


@ -159,3 +159,92 @@ class HybridEmbed(BaseModule):
x = x[-1]
x = self.projection(x).flatten(2).transpose(1, 2)
return x
class PatchMerging(BaseModule):
"""Merge patch feature map.
This layer uses nn.Unfold to group the feature map by kernel_size, and uses
a norm layer and a linear layer to embed the grouped feature map.
Args:
input_resolution (tuple): The size of input patch resolution.
in_channels (int): The num of input channels.
expansion_ratio (Number): Expansion ratio of output channels. The num
of output channels is equal to int(expansion_ratio * in_channels).
kernel_size (int | tuple, optional): the kernel size in the unfold
layer. Defaults to 2.
stride (int | tuple, optional): the stride of the sliding blocks in the
unfold layer. Defaults to kernel_size.
padding (int | tuple, optional): zero padding width in the unfold
layer. Defaults to 0.
dilation (int | tuple, optional): dilation parameter in the unfold
layer. Defaults to 1.
bias (bool, optional): Whether to add bias in linear layer or not.
Defaults to False.
norm_cfg (dict, optional): Config dict for normalization layer.
Defaults to dict(type='LN').
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
input_resolution,
in_channels,
expansion_ratio,
kernel_size=2,
stride=None,
padding=0,
dilation=1,
bias=False,
norm_cfg=dict(type='LN'),
init_cfg=None):
super().__init__(init_cfg)
H, W = input_resolution
self.input_resolution = input_resolution
self.in_channels = in_channels
self.out_channels = int(expansion_ratio * in_channels)
if stride is None:
stride = kernel_size
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
padding = to_2tuple(padding)
dilation = to_2tuple(dilation)
self.sampler = nn.Unfold(kernel_size, dilation, padding, stride)
sample_dim = kernel_size[0] * kernel_size[1] * in_channels
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
else:
self.norm = None
self.reduction = nn.Linear(sample_dim, self.out_channels, bias=bias)
# See https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
H_out = (H + 2 * padding[0] - dilation[0] *
(kernel_size[0] - 1) - 1) // stride[0] + 1
W_out = (W + 2 * padding[1] - dilation[1] *
(kernel_size[1] - 1) - 1) // stride[1] + 1
self.output_resolution = (H_out, W_out)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, 'input feature has wrong size'
x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
# Use nn.Unfold to merge patches. About 25% faster than the original method,
# but the pretrained weights need to be modified for compatibility.
x = self.sampler(x) # B, 4*C, H/2*W/2
x = x.transpose(1, 2) # B, H/2*W/2, 4*C
x = self.norm(x) if self.norm else x
x = self.reduction(x)
return x
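
A minimal sketch of what the `nn.Unfold`-based merging does to a token sequence, and why converted checkpoints need their weights adjusted: to my understanding, `nn.Unfold` lays out the merged features channel by channel, while the original Swin implementation concatenates the four spatial neighbours block by block, so the 4*C ordering differs.

```python
import torch
import torch.nn as nn

B, H, W, C = 2, 8, 8, 16
x = torch.rand(B, H * W, C)  # token sequence, as in forward()

x_img = x.view(B, H, W, C).permute(0, 3, 1, 2)      # B, C, H, W
merged = nn.Unfold(kernel_size=2, stride=2)(x_img)  # B, 4*C, (H/2)*(W/2)
merged = merged.transpose(1, 2)                     # B, (H/2)*(W/2), 4*C
print(merged.shape)  # torch.Size([2, 16, 64])
```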


@ -0,0 +1,177 @@
import numpy as np
import torch
from mmcls.models.utils.attention import ShiftWindowMSA, WindowMSA
def get_relative_position_index(window_size):
"""Method from original code of Swin-Transformer."""
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
# 2, Wh*Ww, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
# Wh*Ww, Wh*Ww, 2
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
return relative_position_index
def test_window_msa():
batch_size = 1
num_windows = (4, 4)
embed_dims = 96
window_size = (7, 7)
num_heads = 4
attn = WindowMSA(
embed_dims=embed_dims, window_size=window_size, num_heads=num_heads)
inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
window_size[0] * window_size[1], embed_dims))
# test forward
output = attn(inputs)
assert output.shape == inputs.shape
assert attn.relative_position_bias_table.shape == (
(2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
# test relative_position_bias_table init
attn.init_weights()
assert abs(attn.relative_position_bias_table).sum() > 0
# test non-square window_size
window_size = (6, 7)
attn = WindowMSA(
embed_dims=embed_dims, window_size=window_size, num_heads=num_heads)
inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
window_size[0] * window_size[1], embed_dims))
output = attn(inputs)
assert output.shape == inputs.shape
# test relative_position_index
expected_rel_pos_index = get_relative_position_index(window_size)
assert (attn.relative_position_index == expected_rel_pos_index).all()
# test qkv_bias=True
attn = WindowMSA(
embed_dims=embed_dims,
window_size=window_size,
num_heads=num_heads,
qkv_bias=True)
assert attn.qkv.bias.shape == (embed_dims * 3, )
# test qkv_bias=False
attn = WindowMSA(
embed_dims=embed_dims,
window_size=window_size,
num_heads=num_heads,
qkv_bias=False)
assert attn.qkv.bias is None
# test default qk_scale
attn = WindowMSA(
embed_dims=embed_dims,
window_size=window_size,
num_heads=num_heads,
qk_scale=None)
head_dims = embed_dims // num_heads
assert np.isclose(attn.scale, head_dims**-0.5)
# test specified qk_scale
attn = WindowMSA(
embed_dims=embed_dims,
window_size=window_size,
num_heads=num_heads,
qk_scale=0.3)
assert attn.scale == 0.3
# test attn_drop
attn = WindowMSA(
embed_dims=embed_dims,
window_size=window_size,
num_heads=num_heads,
attn_drop=1.0)
inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
window_size[0] * window_size[1], embed_dims))
# drop all attn output, output should be equal to proj.bias
assert torch.allclose(attn(inputs), attn.proj.bias)
# test prob_drop
attn = WindowMSA(
embed_dims=embed_dims,
window_size=window_size,
num_heads=num_heads,
proj_drop=1.0)
assert (attn(inputs) == 0).all()
def test_shift_window_msa():
batch_size = 1
embed_dims = 96
input_resolution = (14, 14)
num_heads = 4
window_size = 7
# test forward
attn = ShiftWindowMSA(
embed_dims=embed_dims,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size)
inputs = torch.rand(
(batch_size, input_resolution[0] * input_resolution[1], embed_dims))
output = attn(inputs)
assert output.shape == (inputs.shape)
assert attn.w_msa.relative_position_bias_table.shape == ((2 * window_size -
1)**2, num_heads)
# test forward with shift_size
attn = ShiftWindowMSA(
embed_dims=embed_dims,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=1)
output = attn(inputs)
assert output.shape == (inputs.shape)
# test relative_position_bias_table init
attn.init_weights()
assert abs(attn.w_msa.relative_position_bias_table).sum() > 0
# test dropout_layer
attn = ShiftWindowMSA(
embed_dims=embed_dims,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
dropout_layer=dict(type='DropPath', drop_prob=0.5))
torch.manual_seed(0)
output = attn(inputs)
assert (output == 0).all()
# test auto_pad
input_resolution = (19, 18)
attn = ShiftWindowMSA(
embed_dims=embed_dims,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
auto_pad=True)
assert attn.pad_r == 3
assert attn.pad_b == 2
# test small input_resolution
input_resolution = (5, 6)
attn = ShiftWindowMSA(
embed_dims=embed_dims,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=3,
auto_pad=True)
assert attn.window_size == 5
assert attn.shift_size == 0


@ -0,0 +1,56 @@
import pytest
import torch
from mmcls.models.utils import PatchMerging
def cal_unfold_dim(dim, kernel_size, stride, padding=0, dilation=1):
return (dim + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
def test_patch_merging():
settings = dict(
input_resolution=(56, 56), in_channels=16, expansion_ratio=2)
downsample = PatchMerging(**settings)
# test forward with wrong dims
with pytest.raises(AssertionError):
inputs = torch.rand((1, 16, 56 * 56))
downsample(inputs)
# test patch merging forward
inputs = torch.rand((1, 56 * 56, 16))
out = downsample(inputs)
assert downsample.output_resolution == (28, 28)
assert out.shape == (1, 28 * 28, 32)
# test different kernel_size in each direction
downsample = PatchMerging(kernel_size=(2, 3), **settings)
out = downsample(inputs)
expected_dim = cal_unfold_dim(56, 2, 2) * cal_unfold_dim(56, 3, 3)
assert downsample.sampler.kernel_size == (2, 3)
assert downsample.output_resolution == (cal_unfold_dim(56, 2, 2),
cal_unfold_dim(56, 3, 3))
assert out.shape == (1, expected_dim, 32)
# test default stride
downsample = PatchMerging(kernel_size=6, **settings)
assert downsample.sampler.stride == (6, 6)
# test stride=3
downsample = PatchMerging(kernel_size=6, stride=3, **settings)
out = downsample(inputs)
assert downsample.sampler.stride == (3, 3)
assert out.shape == (1, cal_unfold_dim(56, 6, stride=3)**2, 32)
# test padding
downsample = PatchMerging(kernel_size=6, padding=2, **settings)
out = downsample(inputs)
assert downsample.sampler.padding == (2, 2)
assert out.shape == (1, cal_unfold_dim(56, 6, 6, padding=2)**2, 32)
# test dilation
downsample = PatchMerging(kernel_size=6, dilation=2, **settings)
out = downsample(inputs)
assert downsample.sampler.dilation == (2, 2)
assert out.shape == (1, cal_unfold_dim(56, 6, 6, dilation=2)**2, 32)


@ -0,0 +1,144 @@
from math import ceil
import numpy as np
import pytest
import torch
from mmcls.models.backbones import SwinTransformer
def test_swin_transformer():
"""Test Swin Transformer backbone."""
with pytest.raises(AssertionError):
# Swin Transformer arch string should be one of the predefined archs.
SwinTransformer(arch='unknown')
with pytest.raises(AssertionError):
# Swin Transformer arch dict should include 'embed_dims',
# 'depths' and 'num_heads' keys.
SwinTransformer(arch=dict(embed_dims=96, depths=[2, 2, 18, 2]))
# Test tiny arch forward
model = SwinTransformer(arch='Tiny')
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
output = model(imgs)
assert output.shape == (1, 768, 49)
# Test small arch forward
model = SwinTransformer(arch='small')
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
output = model(imgs)
assert output.shape == (1, 768, 49)
# Test base arch forward
model = SwinTransformer(arch='B')
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
output = model(imgs)
assert output.shape == (1, 1024, 49)
# Test large arch forward
model = SwinTransformer(arch='l')
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
output = model(imgs)
assert output.shape == (1, 1536, 49)
# Test base arch with window_size=12, image_size=384
model = SwinTransformer(
arch='base',
img_size=384,
stage_cfgs=dict(block_cfgs=dict(window_size=12)))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 384, 384)
output = model(imgs)
assert output.shape == (1, 1024, 144)
# Test small with use_abs_pos_embed = True
model = SwinTransformer(arch='small', use_abs_pos_embed=True)
model.init_weights()
model.train()
assert model.absolute_pos_embed.shape == (1, 3136, 96)
# Test small with use_abs_pos_embed = False
with pytest.raises(AttributeError):
model = SwinTransformer(arch='small', use_abs_pos_embed=False)
model.absolute_pos_embed
# Test small with auto_pad = True
model = SwinTransformer(
arch='small',
auto_pad=True,
stage_cfgs=dict(
block_cfgs={'window_size': 7},
downsample_cfg={
'kernel_size': (3, 2),
}))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
# stage 1
input_h = int(224 / 4 / 3)
expect_h = ceil(input_h / 7) * 7
input_w = int(224 / 4 / 2)
expect_w = ceil(input_w / 7) * 7
assert model.stages[1].blocks[0].attn.pad_b == expect_h - input_h
assert model.stages[1].blocks[0].attn.pad_r == expect_w - input_w
# stage 2
input_h = int(224 / 4 / 3 / 3)
# input_h is smaller than window_size, shrink the window_size to input_h.
expect_h = input_h
input_w = int(224 / 4 / 2 / 2)
expect_w = ceil(input_w / input_h) * input_h
assert model.stages[2].blocks[0].attn.pad_b == expect_h - input_h
assert model.stages[2].blocks[0].attn.pad_r == expect_w - input_w
# stage 3
input_h = int(224 / 4 / 3 / 3 / 3)
expect_h = input_h
input_w = int(224 / 4 / 2 / 2 / 2)
expect_w = ceil(input_w / input_h) * input_h
assert model.stages[3].blocks[0].attn.pad_b == expect_h - input_h
assert model.stages[3].blocks[0].attn.pad_r == expect_w - input_w
# Test small with auto_pad = False
with pytest.raises(AssertionError):
model = SwinTransformer(
arch='small',
auto_pad=False,
stage_cfgs=dict(
block_cfgs={'window_size': 7},
downsample_cfg={
'kernel_size': (3, 2),
}))
# Test drop_path_rate decay
model = SwinTransformer(
arch='small',
drop_path_rate=0.2,
)
depths = model.arch_settings['depths']
pos = 0
for i, depth in enumerate(depths):
for j in range(depth):
block = model.stages[i].blocks[j]
expect_prob = 0.2 / (sum(depths) - 1) * pos
assert np.isclose(block.ffn.dropout_layer.drop_prob, expect_prob)
assert np.isclose(block.attn.drop.drop_prob, expect_prob)
pos += 1