[Feature] Add swin-transformer model. (#271)
* Add swin transformer archs S, B and L.
* Add SwinTransformer configs
* Add train config files of swin.
* Align init method with original code
* Use nn.Unfold to merge patch
* Change all ConfigDict to dict
* Add init_cfg for all subclasses of BaseModule.
* Use mmcv version init function
* Add Swin README
* Use safer cfg copy method
* Improve docstring and variable name.
* Fix some differences in randaug. Fix BGR bug, align scheduler config. Fix label smoothing parameter difference.
* Fix missing droppath in attn
* Fix bug of relative position table if window width is not equal to height.
* Make `PatchMerging` more general, support kernel, stride, padding and dilation.
* Rename `residual` to `identity` in attention and FFN.
* Add `auto_pad` option to auto pad feature map
* Improve docstring.
* Fix bug in ShiftWMSA padding.
* Remove unused `key` and `value` in ShiftWMSA
* Move `PatchMerging` into utils and use common `PatchEmbed`.
* Use latest `LinearClsHead`, train augments and label smooth settings. And remove original `SwinLinearClsHead`.
* Mark some configs as "Evaluation Only".
* Remove useless comment in config
* 1. Move ShiftWindowMSA and WindowMSA to `utils/attention.py`
  2. Add docstrings of each module.
  3. Fix some variables' names.
  4. Other small improvements.
* Add unit tests of swin-transformer and patchmerging.
* Fix some bugs in unit tests.
* Fix bug of rel_position_index if window is not square.
* Make WindowMSA implicit, and add unit tests.
* Add metafile.yml, update readme and model_zoo.
parent 4ebee155e8
commit 076ee10cac
@@ -49,6 +49,7 @@ Supported backbones:
- [x] ShuffleNetV2
- [x] MobileNetV2
- [x] MobileNetV3
- [x] Swin-Transformer

## Installation
@@ -49,6 +49,7 @@ MMClassification is an open-source image classification toolbox based on PyTorch, and a member of the [O
- [x] ShuffleNetV2
- [x] MobileNetV2
- [x] MobileNetV3
- [x] Swin-Transformer

## Installation
@@ -0,0 +1,122 @@
# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

policies = [
    dict(type='AutoContrast'),
    dict(type='Equalize'),
    dict(type='Invert'),
    dict(
        type='Rotate',
        interpolation='bicubic',
        magnitude_key='angle',
        pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
        magnitude_range=(0, 30)),
    dict(type='Posterize', magnitude_key='bits', magnitude_range=(4, 0)),
    dict(type='Solarize', magnitude_key='thr', magnitude_range=(256, 0)),
    dict(
        type='SolarizeAdd',
        magnitude_key='magnitude',
        magnitude_range=(0, 110)),
    dict(
        type='ColorTransform',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.9)),
    dict(type='Contrast', magnitude_key='magnitude', magnitude_range=(0, 0.9)),
    dict(
        type='Brightness', magnitude_key='magnitude',
        magnitude_range=(0, 0.9)),
    dict(
        type='Sharpness', magnitude_key='magnitude', magnitude_range=(0, 0.9)),
    dict(
        type='Shear',
        interpolation='bicubic',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
        direction='horizontal'),
    dict(
        type='Shear',
        interpolation='bicubic',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
        direction='vertical'),
    dict(
        type='Translate',
        interpolation='bicubic',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.45),
        pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
        direction='horizontal'),
    dict(
        type='Translate',
        interpolation='bicubic',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.45),
        pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
        direction='vertical')
]

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=224,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(
        type='RandAugment',
        policies=policies,
        num_policies=2,
        total_level=10,
        magnitude_level=9,
        magnitude_std=0.5),
    dict(
        type='RandomErasing',
        erase_prob=0.25,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=img_norm_cfg['mean'][::-1],
        fill_std=img_norm_cfg['std'][::-1]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        size=(256, -1),
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=128,
    workers_per_gpu=8,
    train=dict(
        type=dataset_type,
        data_prefix='data/imagenet/train',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='data/imagenet/val',
        ann_file='data/imagenet/meta/val.txt',
        pipeline=test_pipeline),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        type=dataset_type,
        data_prefix='data/imagenet/val',
        ann_file='data/imagenet/meta/val.txt',
        pipeline=test_pipeline))

evaluation = dict(interval=10, metric='accuracy')
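The `RandAugment` entry above draws `num_policies` transforms per image and scales each policy's `magnitude_range` by `magnitude_level / total_level`, jittered by `magnitude_std`. A minimal sketch of that mapping (the helper name is hypothetical, not an mmcls API):

```python
import random

def sample_magnitude(magnitude_range, magnitude_level=9, total_level=10,
                     magnitude_std=0.5):
    # Hypothetical helper mirroring how the config above scales a policy's
    # magnitude_range, with Gaussian jitter around the chosen level.
    level = random.gauss(magnitude_level, magnitude_std)
    level = min(max(level, 0), total_level)          # clamp to [0, total_level]
    low, high = magnitude_range
    return low + (high - low) * level / total_level  # ranges may also decrease

# With magnitude_level=9 of total_level=10, Rotate (range (0, 30)) samples
# angles around 27 degrees, while Posterize (range (4, 0)) moves toward 0.4.
print(sample_magnitude((0, 30)))
```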
@@ -0,0 +1,43 @@
# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=384,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=384, backend='pillow', interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=128,
    workers_per_gpu=8,
    train=dict(
        type=dataset_type,
        data_prefix='data/imagenet/train',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='data/imagenet/val',
        ann_file='data/imagenet/meta/val.txt',
        pipeline=test_pipeline),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        type=dataset_type,
        data_prefix='data/imagenet/val',
        ann_file='data/imagenet/meta/val.txt',
        pipeline=test_pipeline))
evaluation = dict(interval=10, metric='accuracy')
@@ -0,0 +1,22 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer', arch='base', img_size=224, drop_path_rate=0.5),
    neck=dict(type='GlobalAveragePooling', dim=1),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1024,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
    ],
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
    ]))
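Model configs like this one are consumed by the registry-based builder. A minimal usage sketch, assuming mmcls 0.x is installed (`build_classifier` and `extract_feat` are existing mmcls APIs; the output size follows from `arch='base'`):

```python
import torch
from mmcls.models import build_classifier

# `model` here refers to the config dict defined above.
classifier = build_classifier(model)
classifier.init_weights()  # applies the TruncNormal/Constant init_cfg

x = torch.randn(1, 3, 224, 224)
feats = classifier.extract_feat(x)  # backbone + GAP neck
print(feats.shape)                  # expected: torch.Size([1, 1024])
```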
@@ -0,0 +1,16 @@
# model settings
# Only for evaluation
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer',
        arch='base',
        img_size=384,
        stage_cfgs=dict(block_cfgs=dict(window_size=12))),
    neck=dict(type='GlobalAveragePooling', dim=1),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1024,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5)))
@@ -0,0 +1,12 @@
# model settings
# Only for evaluation
model = dict(
    type='ImageClassifier',
    backbone=dict(type='SwinTransformer', arch='large', img_size=224),
    neck=dict(type='GlobalAveragePooling', dim=1),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1536,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5)))
@@ -0,0 +1,16 @@
# model settings
# Only for evaluation
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer',
        arch='large',
        img_size=384,
        stage_cfgs=dict(block_cfgs=dict(window_size=12))),
    neck=dict(type='GlobalAveragePooling', dim=1),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1536,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5)))
@@ -0,0 +1,23 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer', arch='small', img_size=224,
        drop_path_rate=0.3),
    neck=dict(type='GlobalAveragePooling', dim=1),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=768,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
    ],
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
    ]))
@@ -0,0 +1,22 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer', arch='tiny', img_size=224, drop_path_rate=0.2),
    neck=dict(type='GlobalAveragePooling', dim=1),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=768,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
    ],
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
    ]))
@@ -0,0 +1,30 @@
paramwise_cfg = dict(
    norm_decay_mult=0.0,
    bias_decay_mult=0.0,
    custom_keys={
        '.absolute_pos_embed': dict(decay_mult=0.0),
        '.relative_position_bias_table': dict(decay_mult=0.0)
    })

# With a batch size of 128 per GPU on 8 GPUs (total batch size 1024):
# lr = 5e-4 * 128 * 8 / 512 = 0.001
optimizer = dict(
    type='AdamW',
    lr=5e-4 * 128 * 8 / 512,
    weight_decay=0.05,
    eps=1e-8,
    betas=(0.9, 0.999),
    paramwise_cfg=paramwise_cfg)
optimizer_config = dict(grad_clip=dict(max_norm=5.0))

# learning policy
lr_config = dict(
    policy='CosineAnnealing',
    by_epoch=False,
    min_lr_ratio=1e-2,
    warmup='linear',
    warmup_ratio=1e-3,
    warmup_iters=20 * 1252,
    warmup_by_epoch=False)

runner = dict(type='EpochBasedRunner', max_epochs=300)
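The commented learning rate follows the linear scaling rule (base LR 5e-4 per 512 samples), and `warmup_iters=20 * 1252` is 20 warmup epochs expressed in iterations, where 1252 is the per-epoch iteration count at total batch size 1024 (assuming the standard ImageNet-1k train split of 1,281,167 images). A quick sketch of the arithmetic:

```python
from math import ceil

base_lr, base_batch = 5e-4, 512
samples_per_gpu, num_gpus = 128, 8

total_batch = samples_per_gpu * num_gpus       # 1024
lr = base_lr * total_batch / base_batch        # 0.001
iters_per_epoch = ceil(1281167 / total_batch)  # 1252, hence 20 * 1252 above
warmup_iters = 20 * iters_per_epoch
print(lr, warmup_iters)                        # 0.001 25040
```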
@@ -0,0 +1,41 @@
# Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

## Introduction

[ALGORITHM]

```latex
@article{liu2021Swin,
  title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
  author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},
  journal={arXiv preprint arXiv:2103.14030},
  year={2021}
}
```

## Pretrained models

The pre-trained models are converted from the [model zoo of Swin Transformer](https://github.com/microsoft/Swin-Transformer#main-results-on-imagenet-with-pretrained-models).

### ImageNet 1k

| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Download |
|:---------:|:------------:|:-----------:|:---------:|:---------:|:---------:|:---------:|:--------:|
| Swin-T | ImageNet-1k | 224x224 | 28.29 | 4.36 | 81.18 | 95.52 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_tiny_patch4_window7_224-160bb0a5.pth)|
| Swin-S | ImageNet-1k | 224x224 | 49.61 | 8.52 | 83.21 | 96.25 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_small_patch4_window7_224-cc7a01c9.pth)|
| Swin-B | ImageNet-1k | 224x224 | 87.77 | 15.14 | 83.42 | 96.44 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window7_224-4670dd19.pth)|
| Swin-B | ImageNet-1k | 384x384 | 87.90 | 44.49 | 84.49 | 96.95 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window12_384-02c598a4.pth)|
| Swin-B | ImageNet-22k | 224x224 | 87.77 | 15.14 | 85.16 | 97.50 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window7_224_22kto1k-f967f799.pth)|
| Swin-B | ImageNet-22k | 384x384 | 87.90 | 44.49 | 86.44 | 98.05 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window12_384_22kto1k-d59b0d1d.pth)|
| Swin-L | ImageNet-22k | 224x224 | 196.53 | 34.04 | 86.24 | 97.88 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_large_patch4_window7_224_22kto1k-5f0996db.pth)|
| Swin-L | ImageNet-22k | 384x384 | 196.74 | 100.04 | 87.25 | 98.25 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_large_patch4_window12_384_22kto1k-0a40944b.pth)|

## Results and models

### ImageNet

| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------:|:------------:|:-----------:|:---------:|:---------:|:---------:|:---------:|:----------:|:--------:|
| Swin-T | ImageNet-1k | 224x224 | 28.29 | 4.36 | 81.18 | 95.61 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_tiny_224_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_imagenet-66df6be6.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_imagenet-66df6be6.log.json)|
| Swin-S | ImageNet-1k | 224x224 | 49.61 | 8.52 | 83.02 | 96.29 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_small_224_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_imagenet-7f9d988b.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_imagenet-7f9d988b.log.json)|
| Swin-B | ImageNet-1k | 224x224 | 87.77 | 15.14 | 83.36 | 96.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_base_224_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_imagenet-93230b0d.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_imagenet-93230b0d.log.json)|
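A minimal inference sketch for the checkpoints above, assuming mmcls 0.x (`init_model`/`inference_model` are mmcls high-level APIs; the demo image path is the one shipped in the mmcls repo):

```python
from mmcls.apis import inference_model, init_model

config = 'configs/swin_transformer/swin_tiny_224_imagenet.py'
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_imagenet-66df6be6.pth'

model = init_model(config, checkpoint, device='cpu')
result = inference_model(model, 'demo/demo.JPEG')  # demo image in the repo
print(result['pred_class'], result['pred_score'])
```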
@@ -0,0 +1,67 @@
Collections:
  - Name: Swin-Transformer
    Metadata:
      Training Data: ImageNet
      Training Techniques:
        - AdamW
        - Weight Decay
      Training Resources: 16x V100 GPUs
      Epochs: 300
      Batch Size: 1024
      Architecture:
        - Shift Window Multihead Self Attention
    Paper: https://arxiv.org/pdf/2103.14030.pdf
    README: configs/swin_transformer/README.md

Models:
  - Config: configs/swin_transformer/swin_tiny_224_imagenet.py
    In Collection: Swin-Transformer
    Metadata:
      FLOPs: 4360000000
      Parameters: 28290000
      Training Data: ImageNet
      Training Resources: 16x 1080 GPUs
      Epochs: 300
      Batch Size: 1024
    Name: swin_tiny_224_imagenet
    Results:
      - Dataset: ImageNet
        Metrics:
          Top 1 Accuracy: 81.18
          Top 5 Accuracy: 95.61
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_imagenet-66df6be6.pth
  - Config: configs/swin_transformer/swin_small_224_imagenet.py
    In Collection: Swin-Transformer
    Metadata:
      FLOPs: 8520000000
      Parameters: 49610000
      Training Data: ImageNet
      Training Resources: 16x 1080 GPUs
      Epochs: 300
      Batch Size: 1024
    Name: swin_small_224_imagenet
    Results:
      - Dataset: ImageNet
        Metrics:
          Top 1 Accuracy: 83.02
          Top 5 Accuracy: 96.29
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_imagenet-7f9d988b.pth
  - Config: configs/swin_transformer/swin_base_224_imagenet.py
    In Collection: Swin-Transformer
    Metadata:
      FLOPs: 15140000000
      Parameters: 87770000
      Training Data: ImageNet
      Training Resources: 16x 1080 GPUs
      Epochs: 300
      Batch Size: 1024
    Name: swin_base_224_imagenet
    Results:
      - Dataset: ImageNet
        Metrics:
          Top 1 Accuracy: 83.36
          Top 5 Accuracy: 96.44
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_imagenet-93230b0d.pth
@@ -0,0 +1,6 @@
_base_ = [
    '../_base_/models/swin_transformer/base_224.py',
    '../_base_/datasets/imagenet_bs128_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py'
]
@@ -0,0 +1,7 @@
# Only for evaluation
_base_ = [
    '../_base_/models/swin_transformer/base_384.py',
    '../_base_/datasets/imagenet_bs128_swin_384.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py'
]
@@ -0,0 +1,7 @@
# Only for evaluation
_base_ = [
    '../_base_/models/swin_transformer/large_224.py',
    '../_base_/datasets/imagenet_bs128_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py'
]
@@ -0,0 +1,7 @@
# Only for evaluation
_base_ = [
    '../_base_/models/swin_transformer/large_384.py',
    '../_base_/datasets/imagenet_bs128_swin_384.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py'
]
@@ -0,0 +1,6 @@
_base_ = [
    '../_base_/models/swin_transformer/small_224.py',
    '../_base_/datasets/imagenet_bs128_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py'
]
@@ -0,0 +1,6 @@
_base_ = [
    '../_base_/models/swin_transformer/tiny_224.py',
    '../_base_/datasets/imagenet_bs128_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py'
]
@@ -40,6 +40,9 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
| ViT-B/32* | 88.3 | 8.56 | 81.73 | 96.13 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/vision_transformer/vit_base_patch32_384_finetune_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/vit/vit_base_patch32_384.pth) | [log]() |
| ViT-L/16* | 304.72 | 116.68 | 85.08 | 97.38 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/vision_transformer/vit_large_patch16_384_finetune_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/vit/vit_large_patch16_384.pth) | [log]() |
| ViT-L/32* | 306.63 | 29.66 | 81.52 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/vision_transformer/vit_large_patch32_384_finetune_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/vit/vit_large_patch32_384.pth) | [log]() |
| Swin-Transformer tiny | 28.29 | 4.36 | 81.18 | 95.61 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_tiny_224_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_imagenet-66df6be6.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_imagenet-66df6be6.log.json)|
| Swin-Transformer small| 49.61 | 8.52 | 83.02 | 96.29 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_small_224_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_imagenet-7f9d988b.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_imagenet-7f9d988b.log.json)|
| Swin-Transformer base | 87.77 | 15.14 | 83.36 | 96.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_base_224_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_imagenet-93230b0d.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_imagenet-93230b0d.log.json)|

Models with * are converted from other repos, others are trained by ourselves.
@@ -11,11 +11,13 @@ from .seresnet import SEResNet
from .seresnext import SEResNeXt
from .shufflenet_v1 import ShuffleNetV1
from .shufflenet_v2 import ShuffleNetV2
from .swin_transformer import SwinTransformer
from .vgg import VGG
from .vision_transformer import VisionTransformer

__all__ = [
    'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
    'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
    'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
    'SwinTransformer'
]
@@ -0,0 +1,349 @@
from copy import deepcopy
from typing import Sequence

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList

from ..builder import BACKBONES
from ..utils import PatchEmbed, PatchMerging, ShiftWindowMSA
from .base_backbone import BaseBackbone


class SwinBlock(BaseModule):
    """Swin Transformer block.

    Args:
        embed_dims (int): Number of input channels.
        input_resolution (Tuple[int, int]): The resolution of the input
            feature map.
        num_heads (int): Number of attention heads.
        window_size (int, optional): The height and width of the window.
            Defaults to 7.
        shift (bool, optional): Shift the attention window or not.
            Defaults to False.
        ffn_ratio (float, optional): The expansion ratio of feedforward
            network hidden layer channels. Defaults to 4.
        drop_path (float, optional): The drop path rate after attention and
            ffn. Defaults to 0.
        attn_cfgs (dict, optional): The extra config of Shift Window-MSA.
            Defaults to empty dict.
        ffn_cfgs (dict, optional): The extra config of FFN.
            Defaults to empty dict.
        norm_cfg (dict, optional): The config of norm layers.
            Defaults to dict(type='LN').
        auto_pad (bool, optional): Auto pad the feature map to be divisible by
            window_size. Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 input_resolution,
                 num_heads,
                 window_size=7,
                 shift=False,
                 ffn_ratio=4.,
                 drop_path=0.,
                 attn_cfgs=dict(),
                 ffn_cfgs=dict(),
                 norm_cfg=dict(type='LN'),
                 auto_pad=False,
                 init_cfg=None):

        super(SwinBlock, self).__init__(init_cfg)

        _attn_cfgs = {
            'embed_dims': embed_dims,
            'input_resolution': input_resolution,
            'num_heads': num_heads,
            'shift_size': window_size // 2 if shift else 0,
            'window_size': window_size,
            'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
            'auto_pad': auto_pad,
            **attn_cfgs
        }
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = ShiftWindowMSA(**_attn_cfgs)

        _ffn_cfgs = {
            'embed_dims': embed_dims,
            'feedforward_channels': int(embed_dims * ffn_ratio),
            'num_fcs': 2,
            'ffn_drop': 0,
            'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
            'act_cfg': dict(type='GELU'),
            **ffn_cfgs
        }
        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = FFN(**_ffn_cfgs)

    def forward(self, x):
        # Pre-norm residual branch around the (shifted) window attention.
        identity = x
        x = self.norm1(x)
        x = self.attn(x)
        x = x + identity

        # Pre-norm residual branch around the FFN.
        identity = x
        x = self.norm2(x)
        x = self.ffn(x, identity=identity)
        return x


class SwinBlockSequence(BaseModule):
    """Module with successive Swin Transformer blocks and downsample layer.

    Args:
        embed_dims (int): Number of input channels.
        input_resolution (Tuple[int, int]): The resolution of the input
            feature map.
        depth (int): Number of successive swin transformer blocks.
        num_heads (int): Number of attention heads.
        downsample (bool, optional): Downsample the output of blocks by patch
            merging. Defaults to False.
        downsample_cfg (dict, optional): The extra config of the patch merging
            layer. Defaults to empty dict.
        drop_paths (Sequence[float] | float, optional): The drop path rate in
            each block. Defaults to 0.
        block_cfgs (Sequence[dict] | dict, optional): The extra config of each
            block. Defaults to empty dicts.
        auto_pad (bool, optional): Auto pad the feature map to be divisible by
            window_size. Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 input_resolution,
                 depth,
                 num_heads,
                 downsample=False,
                 downsample_cfg=dict(),
                 drop_paths=0.,
                 block_cfgs=dict(),
                 auto_pad=False,
                 init_cfg=None):
        super().__init__(init_cfg)

        if not isinstance(drop_paths, Sequence):
            drop_paths = [drop_paths] * depth

        if not isinstance(block_cfgs, Sequence):
            # Broadcast a single block config to every block in this stage.
            block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]

        self.blocks = ModuleList()
        for i in range(depth):
            _block_cfg = {
                'embed_dims': embed_dims,
                'input_resolution': input_resolution,
                'num_heads': num_heads,
                'shift': False if i % 2 == 0 else True,
                'drop_path': drop_paths[i],
                'auto_pad': auto_pad,
                **block_cfgs[i]
            }
            block = SwinBlock(**_block_cfg)
            self.blocks.append(block)

        if downsample:
            _downsample_cfg = {
                'input_resolution': input_resolution,
                'in_channels': embed_dims,
                'expansion_ratio': 2,
                'norm_cfg': dict(type='LN'),
                **downsample_cfg
            }
            self.downsample = PatchMerging(**_downsample_cfg)
        else:
            self.downsample = None

    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        if self.downsample:
            x = self.downsample(x)
        return x


@BACKBONES.register_module()
class SwinTransformer(BaseBackbone):
    """Swin Transformer.

    A PyTorch implementation of: `Swin Transformer:
    Hierarchical Vision Transformer using Shifted Windows` -
    https://arxiv.org/abs/2103.14030

    Inspiration from
    https://github.com/microsoft/Swin-Transformer

    Args:
        arch (str | dict): Swin Transformer architecture. Defaults to 'T'.
        img_size (int | tuple): The size of input image. Defaults to 224.
        in_channels (int): The num of input channels. Defaults to 3.
        drop_rate (float): Dropout rate after embedding. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Defaults to False.
        auto_pad (bool): If True, auto pad feature map to fit window_size.
            Defaults to False.
        norm_cfg (dict, optional): Config dict for the normalization layer at
            the end of the backbone. Defaults to dict(type='LN').
        stage_cfgs (Sequence | dict, optional): Extra config dict for each
            stage. Defaults to empty dict.
        patch_cfg (dict, optional): Extra config dict for patch embedding.
            Defaults to empty dict.
        init_cfg (dict, optional): The config for initialization.
            Defaults to None.

    Examples:
        >>> from mmcls.models import SwinTransformer
        >>> import torch
        >>> extra_config = dict(
        >>>     arch='tiny',
        >>>     stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
        >>>                                     'expansion_ratio': 3}),
        >>>     auto_pad=True)
        >>> self = SwinTransformer(**extra_config)
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> output = self.forward(inputs)
        >>> print(output.shape)
        (1, 2592, 4)
    """
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'], {
            'embed_dims': 96,
            'depths': [2, 2, 6, 2],
            'num_heads': [3, 6, 12, 24]
        }),
        **dict.fromkeys(['s', 'small'], {
            'embed_dims': 96,
            'depths': [2, 2, 18, 2],
            'num_heads': [3, 6, 12, 24]
        }),
        **dict.fromkeys(['b', 'base'], {
            'embed_dims': 128,
            'depths': [2, 2, 18, 2],
            'num_heads': [4, 8, 16, 32]
        }),
        **dict.fromkeys(['l', 'large'], {
            'embed_dims': 192,
            'depths': [2, 2, 18, 2],
            'num_heads': [6, 12, 24, 48]
        }),
    }  # yapf: disable

    def __init__(self,
                 arch='T',
                 img_size=224,
                 in_channels=3,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 use_abs_pos_embed=False,
                 auto_pad=False,
                 norm_cfg=dict(type='LN'),
                 stage_cfgs=dict(),
                 patch_cfg=dict(),
                 init_cfg=None):
        super(SwinTransformer, self).__init__(init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {'embed_dims', 'depths', 'num_heads'}
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.num_heads = self.arch_settings['num_heads']
        self.num_layers = len(self.depths)
        self.use_abs_pos_embed = use_abs_pos_embed
        self.auto_pad = auto_pad

        _patch_cfg = dict(
            img_size=img_size,
            in_channels=in_channels,
            embed_dims=self.embed_dims,
            conv_cfg=dict(
                type='Conv2d', kernel_size=4, stride=4, padding=0, dilation=1),
            norm_cfg=dict(type='LN'),
            **patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        if self.use_abs_pos_embed:
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, self.embed_dims))

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        # stochastic depth
        total_depth = sum(self.depths)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]  # stochastic depth decay rule

        self.stages = ModuleList()
        embed_dims = self.embed_dims
        input_resolution = patches_resolution
        for i, (depth,
                num_heads) in enumerate(zip(self.depths, self.num_heads)):
            if isinstance(stage_cfgs, Sequence):
                stage_cfg = stage_cfgs[i]
            else:
                stage_cfg = deepcopy(stage_cfgs)
            downsample = True if i < self.num_layers - 1 else False
            _stage_cfg = {
                'embed_dims': embed_dims,
                'depth': depth,
                'num_heads': num_heads,
                'downsample': downsample,
                'input_resolution': input_resolution,
                'drop_paths': dpr[:depth],
                'auto_pad': auto_pad,
                **stage_cfg
            }

            stage = SwinBlockSequence(**_stage_cfg)
            self.stages.append(stage)

            # Consume this stage's drop-path rates and track the shape of
            # the feature map fed to the next stage.
            dpr = dpr[depth:]
            if downsample:
                embed_dims = stage.downsample.out_channels
                input_resolution = stage.downsample.output_resolution

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

    def init_weights(self):
        super(SwinTransformer, self).init_weights()

        if self.use_abs_pos_embed:
            trunc_normal_(self.absolute_pos_embed, std=0.02)

    def forward(self, x):
        x = self.patch_embed(x)
        if self.use_abs_pos_embed:
            x = x + self.absolute_pos_embed
        x = self.drop_after_pos(x)

        for stage in self.stages:
            x = stage(x)

        x = self.norm(x) if self.norm else x

        return x.transpose(1, 2)
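The custom-arch branch in `__init__` accepts a dict in place of a preset name. A short sketch, using values equivalent to the built-in 'tiny' preset (the expected output shape matches the unit tests below):

```python
import torch
from mmcls.models import SwinTransformer

custom_arch = dict(
    embed_dims=96,
    depths=[2, 2, 6, 2],
    num_heads=[3, 6, 12, 24],
)  # same settings as arch='tiny'
model = SwinTransformer(arch=custom_arch)
out = model(torch.rand(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 768, 49])
```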
@@ -1,6 +1,7 @@
from .attention import ShiftWindowMSA
from .augment.augments import Augments
from .channel_shuffle import channel_shuffle
from .embed import HybridEmbed, PatchEmbed, PatchMerging
from .helpers import to_2tuple, to_3tuple, to_4tuple, to_ntuple
from .inverted_residual import InvertedResidual
from .make_divisible import make_divisible

@@ -8,6 +9,6 @@ from .se_layer import SELayer

__all__ = [
    'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer',
    'to_ntuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'PatchEmbed',
    'PatchMerging', 'HybridEmbed', 'Augments', 'ShiftWindowMSA'
]
@@ -0,0 +1,289 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.cnn.bricks.transformer import build_dropout
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule

from .helpers import to_2tuple


class WindowMSA(BaseModule):
    """Window based multi-head self-attention (W-MSA) module with relative
    position bias.

    Args:
        embed_dims (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Defaults to True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Defaults to None.
        attn_drop (float, optional): Dropout ratio of attention weight.
            Defaults to 0.
        proj_drop (float, optional): Dropout ratio of output. Defaults to 0.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 window_size,
                 num_heads,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 init_cfg=None):

        super().__init__(init_cfg)
        self.embed_dims = embed_dims
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_embed_dims = embed_dims // num_heads
        self.scale = qk_scale or head_embed_dims**-0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # About 2x faster than original impl
        Wh, Ww = self.window_size
        rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
        rel_position_index = rel_index_coords + rel_index_coords.T
        rel_position_index = rel_position_index.flip(1).contiguous()
        self.register_buffer('relative_position_index', rel_position_index)

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def init_weights(self):
        super(WindowMSA, self).init_weights()

        trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x, mask=None):
        """
        Args:
            x (tensor): input features with shape of (num_windows*B, N, C)
            mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww,
                Wh*Ww), value should be between (-inf, 0].
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[
            2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N,
                             N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @staticmethod
    def double_step_seq(step1, len1, step2, len2):
        seq1 = torch.arange(0, step1 * len1, step1)
        seq2 = torch.arange(0, step2 * len2, step2)
        return (seq1[:, None] + seq2[None, :]).reshape(1, -1)


@ATTENTION.register_module()
class ShiftWindowMSA(BaseModule):
    """Shift Window Multihead Self-Attention Module.

    Args:
        embed_dims (int): Number of input channels.
        input_resolution (Tuple[int, int]): The resolution of the input
            feature map.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window.
        shift_size (int, optional): The shift step of each window towards
            right-bottom. If zero, act as regular window-msa. Defaults to 0.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Defaults to True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Defaults to None.
        attn_drop (float, optional): Dropout ratio of attention weight.
            Defaults to 0.0.
        proj_drop (float, optional): Dropout ratio of output. Defaults to 0.
        dropout_layer (dict, optional): The dropout_layer used before output.
            Defaults to dict(type='DropPath', drop_prob=0.).
        auto_pad (bool, optional): Auto pad the feature map to be divisible by
            window_size. Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 input_resolution,
                 num_heads,
                 window_size,
                 shift_size=0,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0,
                 proj_drop=0,
                 dropout_layer=dict(type='DropPath', drop_prob=0.),
                 auto_pad=False,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.embed_dims = embed_dims
        self.input_resolution = input_resolution
        self.shift_size = shift_size
        self.window_size = window_size
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, don't partition
            self.shift_size = 0
            self.window_size = min(self.input_resolution)

        self.w_msa = WindowMSA(embed_dims, to_2tuple(self.window_size),
                               num_heads, qkv_bias, qk_scale, attn_drop,
                               proj_drop)

        self.drop = build_dropout(dropout_layer)

        H, W = self.input_resolution
        # Handle auto padding
        self.auto_pad = auto_pad
        if self.auto_pad:
            self.pad_r = (self.window_size -
                          W % self.window_size) % self.window_size
            self.pad_b = (self.window_size -
                          H % self.window_size) % self.window_size
            self.H_pad = H + self.pad_b
            self.W_pad = W + self.pad_r
        else:
            H_pad, W_pad = self.input_resolution
            assert H_pad % self.window_size + W_pad % self.window_size == 0,\
                f'input_resolution({self.input_resolution}) is not divisible '\
                f'by window_size({self.window_size}). Please check feature '\
                f'map shape or set `auto_pad=True`.'
            self.H_pad, self.W_pad = H_pad, W_pad
            self.pad_r, self.pad_b = 0, 0

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            img_mask = torch.zeros((1, self.H_pad, self.W_pad, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size,
                              -self.shift_size), slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size,
                              -self.shift_size), slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            # nW, window_size, window_size, 1
            mask_windows = self.window_partition(img_mask)
            mask_windows = mask_windows.view(
                -1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0,
                                              float(-100.0)).masked_fill(
                                                  attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer('attn_mask', attn_mask)

    def forward(self, query):
        H, W = self.input_resolution
        B, L, C = query.shape
        assert L == H * W, 'input feature has wrong size'
        query = query.view(B, H, W, C)

        if self.pad_r or self.pad_b:
            query = F.pad(query, (0, 0, 0, self.pad_r, 0, self.pad_b))

        # cyclic shift
        if self.shift_size > 0:
            shifted_query = torch.roll(
                query,
                shifts=(-self.shift_size, -self.shift_size),
                dims=(1, 2))
        else:
            shifted_query = query

        # nW*B, window_size, window_size, C
        query_windows = self.window_partition(shifted_query)
        # nW*B, window_size*window_size, C
        query_windows = query_windows.view(-1, self.window_size**2, C)

        # W-MSA/SW-MSA (nW*B, window_size*window_size, C)
        attn_windows = self.w_msa(query_windows, mask=self.attn_mask)

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size,
                                         self.window_size, C)

        # B H' W' C
        shifted_x = self.window_reverse(attn_windows, self.H_pad, self.W_pad)
        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(
                shifted_x,
                shifts=(self.shift_size, self.shift_size),
                dims=(1, 2))
        else:
            x = shifted_x

        if self.pad_r or self.pad_b:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        x = self.drop(x)
        return x

    def window_reverse(self, windows, H, W):
        window_size = self.window_size
        B = int(windows.shape[0] / (H * W / window_size / window_size))
        x = windows.view(B, H // window_size, W // window_size, window_size,
                         window_size, -1)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
        return x

    def window_partition(self, x):
        B, H, W, C = x.shape
        window_size = self.window_size
        x = x.view(B, H // window_size, window_size, W // window_size,
                   window_size, C)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        windows = windows.view(-1, window_size, window_size, C)
        return windows
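A quick sketch checking that `window_partition` and `window_reverse` above are exact inverses when the feature map is already divisible by the window size:

```python
import torch
from mmcls.models.utils import ShiftWindowMSA

attn = ShiftWindowMSA(
    embed_dims=96, input_resolution=(14, 14), num_heads=4, window_size=7)

x = torch.rand(2, 14, 14, 96)                    # B, H, W, C
windows = attn.window_partition(x)               # (2 * 4 windows, 7, 7, 96)
restored = attn.window_reverse(windows, 14, 14)  # back to B, H, W, C
assert torch.equal(x, restored)
```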
@@ -159,3 +159,92 @@ class HybridEmbed(BaseModule):
            x = x[-1]
        x = self.projection(x).flatten(2).transpose(1, 2)
        return x


class PatchMerging(BaseModule):
    """Merge patch feature map.

    This layer uses nn.Unfold to group the feature map by kernel_size, and a
    norm and a linear layer to embed the grouped feature map.

    Args:
        input_resolution (tuple): The size of input patch resolution.
        in_channels (int): The num of input channels.
        expansion_ratio (Number): Expansion ratio of output channels. The num
            of output channels is equal to int(expansion_ratio * in_channels).
        kernel_size (int | tuple, optional): the kernel size in the unfold
            layer. Defaults to 2.
        stride (int | tuple, optional): the stride of the sliding blocks in
            the unfold layer. Defaults to be equal with kernel_size.
        padding (int | tuple, optional): zero padding width in the unfold
            layer. Defaults to 0.
        dilation (int | tuple, optional): dilation parameter in the unfold
            layer. Defaults to 1.
        bias (bool, optional): Whether to add bias in linear layer or not.
            Defaults to False.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Defaults to dict(type='LN').
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 input_resolution,
                 in_channels,
                 expansion_ratio,
                 kernel_size=2,
                 stride=None,
                 padding=0,
                 dilation=1,
                 bias=False,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super().__init__(init_cfg)
        H, W = input_resolution
        self.input_resolution = input_resolution
        self.in_channels = in_channels
        self.out_channels = int(expansion_ratio * in_channels)

        if stride is None:
            stride = kernel_size
        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        padding = to_2tuple(padding)
        dilation = to_2tuple(dilation)
        self.sampler = nn.Unfold(kernel_size, dilation, padding, stride)

        sample_dim = kernel_size[0] * kernel_size[1] * in_channels

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
        else:
            self.norm = None

        self.reduction = nn.Linear(sample_dim, self.out_channels, bias=bias)

        # See https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
        H_out = (H + 2 * padding[0] - dilation[0] *
                 (kernel_size[0] - 1) - 1) // stride[0] + 1
        W_out = (W + 2 * padding[1] - dilation[1] *
                 (kernel_size[1] - 1) - 1) // stride[1] + 1
        self.output_resolution = (H_out, W_out)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, 'input feature has wrong size'

        x = x.view(B, H, W, C).permute([0, 3, 1, 2])  # B, C, H, W

        # Use nn.Unfold to merge patches. About 25% faster than the original
        # method, but the pretrained model needs to be modified for
        # compatibility.
        x = self.sampler(x)  # B, 4*C, H/2*W/2
        x = x.transpose(1, 2)  # B, H/2*W/2, 4*C

        x = self.norm(x) if self.norm else x
        x = self.reduction(x)

        return x
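A minimal sketch of the default Swin-style merge (`kernel_size=2`): each spatial side is halved and the linear reduction maps the `4*C` unfolded channels to `expansion_ratio*C` outputs:

```python
import torch
from mmcls.models.utils import PatchMerging

merge = PatchMerging(
    input_resolution=(56, 56), in_channels=96, expansion_ratio=2)

x = torch.rand(1, 56 * 56, 96)  # B, H*W, C
out = merge(x)
print(merge.output_resolution)  # (28, 28)
print(out.shape)                # torch.Size([1, 784, 192])
```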
@@ -0,0 +1,177 @@
import numpy as np
import torch

from mmcls.models.utils.attention import ShiftWindowMSA, WindowMSA


def get_relative_position_index(window_size):
    """Method from original code of Swin-Transformer."""
    coords_h = torch.arange(window_size[0])
    coords_w = torch.arange(window_size[1])
    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    # 2, Wh*Ww, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
    # Wh*Ww, Wh*Ww, 2
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
    relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] += window_size[1] - 1
    relative_coords[:, :, 0] *= 2 * window_size[1] - 1
    relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
    return relative_position_index


def test_window_msa():
    batch_size = 1
    num_windows = (4, 4)
    embed_dims = 96
    window_size = (7, 7)
    num_heads = 4
    attn = WindowMSA(
        embed_dims=embed_dims, window_size=window_size, num_heads=num_heads)
    inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
                         window_size[0] * window_size[1], embed_dims))

    # test forward
    output = attn(inputs)
    assert output.shape == inputs.shape
    assert attn.relative_position_bias_table.shape == (
        (2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)

    # test relative_position_bias_table init
    attn.init_weights()
    assert abs(attn.relative_position_bias_table).sum() > 0

    # test non-square window_size
    window_size = (6, 7)
    attn = WindowMSA(
        embed_dims=embed_dims, window_size=window_size, num_heads=num_heads)
    inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
                         window_size[0] * window_size[1], embed_dims))
    output = attn(inputs)
    assert output.shape == inputs.shape

    # test relative_position_index
    expected_rel_pos_index = get_relative_position_index(window_size)
    assert (attn.relative_position_index == expected_rel_pos_index).all()

    # test qkv_bias=True
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        qkv_bias=True)
    assert attn.qkv.bias.shape == (embed_dims * 3, )

    # test qkv_bias=False
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        qkv_bias=False)
    assert attn.qkv.bias is None

    # test default qk_scale
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        qk_scale=None)
    head_dims = embed_dims // num_heads
    assert np.isclose(attn.scale, head_dims**-0.5)

    # test specified qk_scale
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        qk_scale=0.3)
    assert attn.scale == 0.3

    # test attn_drop
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        attn_drop=1.0)
    inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
                         window_size[0] * window_size[1], embed_dims))
    # drop all attn output, output should be equal to proj.bias
    assert torch.allclose(attn(inputs), attn.proj.bias)

    # test proj_drop
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        proj_drop=1.0)
    assert (attn(inputs) == 0).all()


def test_shift_window_msa():
    batch_size = 1
    embed_dims = 96
    input_resolution = (14, 14)
    num_heads = 4
    window_size = 7

    # test forward
    attn = ShiftWindowMSA(
        embed_dims=embed_dims,
        input_resolution=input_resolution,
        num_heads=num_heads,
        window_size=window_size)
    inputs = torch.rand(
        (batch_size, input_resolution[0] * input_resolution[1], embed_dims))
    output = attn(inputs)
    assert output.shape == inputs.shape
    assert attn.w_msa.relative_position_bias_table.shape == (
        (2 * window_size - 1)**2, num_heads)

    # test forward with shift_size
    attn = ShiftWindowMSA(
        embed_dims=embed_dims,
        input_resolution=input_resolution,
        num_heads=num_heads,
        window_size=window_size,
        shift_size=1)
    output = attn(inputs)
    assert output.shape == inputs.shape

    # test relative_position_bias_table init
    attn.init_weights()
    assert abs(attn.w_msa.relative_position_bias_table).sum() > 0

    # test dropout_layer
    attn = ShiftWindowMSA(
        embed_dims=embed_dims,
        input_resolution=input_resolution,
        num_heads=num_heads,
        window_size=window_size,
        dropout_layer=dict(type='DropPath', drop_prob=0.5))
    torch.manual_seed(0)
    output = attn(inputs)
    assert (output == 0).all()

    # test auto_pad
    input_resolution = (19, 18)
    attn = ShiftWindowMSA(
        embed_dims=embed_dims,
        input_resolution=input_resolution,
        num_heads=num_heads,
        window_size=window_size,
        auto_pad=True)
    assert attn.pad_r == 3
    assert attn.pad_b == 2

    # test small input_resolution
    input_resolution = (5, 6)
    attn = ShiftWindowMSA(
        embed_dims=embed_dims,
        input_resolution=input_resolution,
        num_heads=num_heads,
        window_size=window_size,
        shift_size=3,
        auto_pad=True)
    assert attn.window_size == 5
    assert attn.shift_size == 0
@@ -0,0 +1,56 @@
import pytest
import torch

from mmcls.models.utils import PatchMerging


def cal_unfold_dim(dim, kernel_size, stride, padding=0, dilation=1):
    return (dim + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def test_patch_merging():
    settings = dict(
        input_resolution=(56, 56), in_channels=16, expansion_ratio=2)
    downsample = PatchMerging(**settings)

    # test forward with wrong dims
    with pytest.raises(AssertionError):
        inputs = torch.rand((1, 16, 56 * 56))
        downsample(inputs)

    # test patch merging forward
    inputs = torch.rand((1, 56 * 56, 16))
    out = downsample(inputs)
    assert downsample.output_resolution == (28, 28)
    assert out.shape == (1, 28 * 28, 32)

    # test different kernel_size in each direction
    downsample = PatchMerging(kernel_size=(2, 3), **settings)
    out = downsample(inputs)
    expected_dim = cal_unfold_dim(56, 2, 2) * cal_unfold_dim(56, 3, 3)
    assert downsample.sampler.kernel_size == (2, 3)
    assert downsample.output_resolution == (cal_unfold_dim(56, 2, 2),
                                            cal_unfold_dim(56, 3, 3))
    assert out.shape == (1, expected_dim, 32)

    # test default stride
    downsample = PatchMerging(kernel_size=6, **settings)
    assert downsample.sampler.stride == (6, 6)

    # test stride=3
    downsample = PatchMerging(kernel_size=6, stride=3, **settings)
    out = downsample(inputs)
    assert downsample.sampler.stride == (3, 3)
    assert out.shape == (1, cal_unfold_dim(56, 6, stride=3)**2, 32)

    # test padding
    downsample = PatchMerging(kernel_size=6, padding=2, **settings)
    out = downsample(inputs)
    assert downsample.sampler.padding == (2, 2)
    assert out.shape == (1, cal_unfold_dim(56, 6, 6, padding=2)**2, 32)

    # test dilation
    downsample = PatchMerging(kernel_size=6, dilation=2, **settings)
    out = downsample(inputs)
    assert downsample.sampler.dilation == (2, 2)
    assert out.shape == (1, cal_unfold_dim(56, 6, 6, dilation=2)**2, 32)
@@ -0,0 +1,144 @@
from math import ceil

import numpy as np
import pytest
import torch

from mmcls.models.backbones import SwinTransformer


def test_swin_transformer():
    """Test Swin Transformer backbone."""
    with pytest.raises(AssertionError):
        # Swin Transformer arch string should be one of the default archs.
        SwinTransformer(arch='unknown')

    with pytest.raises(AssertionError):
        # Swin Transformer arch dict should include 'embed_dims',
        # 'depths' and 'num_heads' keys.
        SwinTransformer(arch=dict(embed_dims=96, depths=[2, 2, 18, 2]))

    # Test tiny arch forward
    model = SwinTransformer(arch='Tiny')
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    output = model(imgs)
    assert output.shape == (1, 768, 49)

    # Test small arch forward
    model = SwinTransformer(arch='small')
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    output = model(imgs)
    assert output.shape == (1, 768, 49)

    # Test base arch forward
    model = SwinTransformer(arch='B')
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    output = model(imgs)
    assert output.shape == (1, 1024, 49)

    # Test large arch forward
    model = SwinTransformer(arch='l')
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    output = model(imgs)
    assert output.shape == (1, 1536, 49)

    # Test base arch with window_size=12, image_size=384
    model = SwinTransformer(
        arch='base',
        img_size=384,
        stage_cfgs=dict(block_cfgs=dict(window_size=12)))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 384, 384)
    output = model(imgs)
    assert output.shape == (1, 1024, 144)

    # Test small with use_abs_pos_embed = True
    model = SwinTransformer(arch='small', use_abs_pos_embed=True)
    model.init_weights()
    model.train()

    assert model.absolute_pos_embed.shape == (1, 3136, 96)

    # Test small with use_abs_pos_embed = False
    with pytest.raises(AttributeError):
        model = SwinTransformer(arch='small', use_abs_pos_embed=False)
        model.absolute_pos_embed

    # Test small with auto_pad = True
    model = SwinTransformer(
        arch='small',
        auto_pad=True,
        stage_cfgs=dict(
            block_cfgs={'window_size': 7},
            downsample_cfg={
                'kernel_size': (3, 2),
            }))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)

    # stage 1
    input_h = int(224 / 4 / 3)
    expect_h = ceil(input_h / 7) * 7
    input_w = int(224 / 4 / 2)
    expect_w = ceil(input_w / 7) * 7
    assert model.stages[1].blocks[0].attn.pad_b == expect_h - input_h
    assert model.stages[1].blocks[0].attn.pad_r == expect_w - input_w

    # stage 2
    input_h = int(224 / 4 / 3 / 3)
    # input_h is smaller than window_size, shrink the window_size to input_h.
    expect_h = input_h
    input_w = int(224 / 4 / 2 / 2)
    expect_w = ceil(input_w / input_h) * input_h
    assert model.stages[2].blocks[0].attn.pad_b == expect_h - input_h
    assert model.stages[2].blocks[0].attn.pad_r == expect_w - input_w

    # stage 3
    input_h = int(224 / 4 / 3 / 3 / 3)
    expect_h = input_h
    input_w = int(224 / 4 / 2 / 2 / 2)
    expect_w = ceil(input_w / input_h) * input_h
    assert model.stages[3].blocks[0].attn.pad_b == expect_h - input_h
    assert model.stages[3].blocks[0].attn.pad_r == expect_w - input_w

    # Test small with auto_pad = False
    with pytest.raises(AssertionError):
        model = SwinTransformer(
            arch='small',
            auto_pad=False,
            stage_cfgs=dict(
                block_cfgs={'window_size': 7},
                downsample_cfg={
                    'kernel_size': (3, 2),
                }))

    # Test drop_path_rate decay
    model = SwinTransformer(
        arch='small',
        drop_path_rate=0.2,
    )
    depths = model.arch_settings['depths']
    pos = 0
    for i, depth in enumerate(depths):
        for j in range(depth):
            block = model.stages[i].blocks[j]
            expect_prob = 0.2 / (sum(depths) - 1) * pos
            assert np.isclose(block.ffn.dropout_layer.drop_prob, expect_prob)
            assert np.isclose(block.attn.drop.drop_prob, expect_prob)
            pos += 1