From 076ee10cacbda16d7c1bcd5514a64df8a826682b Mon Sep 17 00:00:00 2001 From: Ma Zerun Date: Thu, 1 Jul 2021 09:30:42 +0800 Subject: [PATCH] [Feature] Add swin-transformer model. (#271) * Add swin transformer archs S, B and L. * Add SwinTransformer configs * Add train config files of swin. * Align init method with original code * Use nn.Unfold to merge patch * Change all ConfigDict to dict * Add init_cfg for all subclasses of BaseModule. * Use mmcv version init function * Add Swin README * Use safer cfg copy method * Improve docstring and variable name. * Fix some difference in randaug Fix BGR bug, align scheduler config. Fix label smoothing parameter difference. * Fix missing droppath in attn * Fix bug of relative posititon table if window width is not equal to height. * Make `PatchMerging` more general, support kernel, stride, padding and dilation. * Rename `residual` to `identity` in attention and FFN. * Add `auto_pad` option to auto pad feature map * Improve docstring. * Fix bug in ShiftWMSA padding. * Remove unused `key` and `value` in ShiftWMSA * Move `PatchMerging` into utils and use common `PatchEmbed`. * Use latest `LinearClsHead`, train augments and label smooth settings. And remove original `SwinLinearClsHead`. * Mark some configs as "Evalution Only". * Remove useless comment in config * 1. Move ShiftWindowMSA and WindowMSA to `utils/attention.py` 2. Add docstrings of each module. 3. Fix some variables' names. 4. Other small improvement. * Add unit tests of swin-transformer and patchmerging. * Fix some bugs in unit tests. * Fix bug of rel_position_index if window is not square. * Make WindowMSA implicit, and add unit tests. * Add metafile.yml, update readme and model_zoo. --- README.md | 1 + README_zh-CN.md | 1 + .../datasets/imagenet_bs128_swin_224.py | 122 ++++++ .../datasets/imagenet_bs128_swin_384.py | 43 +++ .../models/swin_transformer/base_224.py | 22 ++ .../models/swin_transformer/base_384.py | 16 + .../models/swin_transformer/large_224.py | 12 + .../models/swin_transformer/large_384.py | 16 + .../models/swin_transformer/small_224.py | 23 ++ .../models/swin_transformer/tiny_224.py | 22 ++ .../schedules/imagenet_bs1024_adamw_swin.py | 30 ++ configs/swin_transformer/README.md | 41 ++ configs/swin_transformer/metafile.yml | 67 ++++ .../swin_base_224_imagenet.py | 6 + .../swin_base_384_imagenet.py | 7 + .../swin_large_224_imagenet.py | 7 + .../swin_large_384_imagenet.py | 7 + .../swin_small_224_imagenet.py | 6 + .../swin_tiny_224_imagenet.py | 6 + docs/model_zoo.md | 3 + mmcls/models/backbones/__init__.py | 4 +- mmcls/models/backbones/swin_transformer.py | 349 ++++++++++++++++++ mmcls/models/utils/__init__.py | 7 +- mmcls/models/utils/attention.py | 289 +++++++++++++++ mmcls/models/utils/embed.py | 89 +++++ tests/test_backbones/test_attention.py | 177 +++++++++ tests/test_backbones/test_embed.py | 56 +++ tests/test_backbones/test_swin_transformer.py | 144 ++++++++ 28 files changed, 1569 insertions(+), 4 deletions(-) create mode 100644 configs/_base_/datasets/imagenet_bs128_swin_224.py create mode 100644 configs/_base_/datasets/imagenet_bs128_swin_384.py create mode 100644 configs/_base_/models/swin_transformer/base_224.py create mode 100644 configs/_base_/models/swin_transformer/base_384.py create mode 100644 configs/_base_/models/swin_transformer/large_224.py create mode 100644 configs/_base_/models/swin_transformer/large_384.py create mode 100644 configs/_base_/models/swin_transformer/small_224.py create mode 100644 configs/_base_/models/swin_transformer/tiny_224.py create 
mode 100644 configs/_base_/schedules/imagenet_bs1024_adamw_swin.py create mode 100644 configs/swin_transformer/README.md create mode 100644 configs/swin_transformer/metafile.yml create mode 100644 configs/swin_transformer/swin_base_224_imagenet.py create mode 100644 configs/swin_transformer/swin_base_384_imagenet.py create mode 100644 configs/swin_transformer/swin_large_224_imagenet.py create mode 100644 configs/swin_transformer/swin_large_384_imagenet.py create mode 100644 configs/swin_transformer/swin_small_224_imagenet.py create mode 100644 configs/swin_transformer/swin_tiny_224_imagenet.py create mode 100644 mmcls/models/backbones/swin_transformer.py create mode 100644 mmcls/models/utils/attention.py create mode 100644 tests/test_backbones/test_attention.py create mode 100644 tests/test_backbones/test_embed.py create mode 100644 tests/test_backbones/test_swin_transformer.py diff --git a/README.md b/README.md index 53838691..4807a1e6 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,7 @@ Supported backbones: - [x] ShuffleNetV2 - [x] MobileNetV2 - [x] MobileNetV3 +- [x] Swin-Transformer ## Installation diff --git a/README_zh-CN.md b/README_zh-CN.md index 8da2aeec..82f03995 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -49,6 +49,7 @@ MMClassification 是一款基于 PyTorch 的开源图像分类工具箱,是 [O - [x] ShuffleNetV2 - [x] MobileNetV2 - [x] MobileNetV3 +- [x] Swin-Transformer ## 安装 diff --git a/configs/_base_/datasets/imagenet_bs128_swin_224.py b/configs/_base_/datasets/imagenet_bs128_swin_224.py new file mode 100644 index 00000000..7209cce7 --- /dev/null +++ b/configs/_base_/datasets/imagenet_bs128_swin_224.py @@ -0,0 +1,122 @@ +# dataset settings +dataset_type = 'ImageNet' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +policies = [ + dict(type='AutoContrast'), + dict(type='Equalize'), + dict(type='Invert'), + dict( + type='Rotate', + interpolation='bicubic', + magnitude_key='angle', + pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]), + magnitude_range=(0, 30)), + dict(type='Posterize', magnitude_key='bits', magnitude_range=(4, 0)), + dict(type='Solarize', magnitude_key='thr', magnitude_range=(256, 0)), + dict( + type='SolarizeAdd', + magnitude_key='magnitude', + magnitude_range=(0, 110)), + dict( + type='ColorTransform', + magnitude_key='magnitude', + magnitude_range=(0, 0.9)), + dict(type='Contrast', magnitude_key='magnitude', magnitude_range=(0, 0.9)), + dict( + type='Brightness', magnitude_key='magnitude', + magnitude_range=(0, 0.9)), + dict( + type='Sharpness', magnitude_key='magnitude', magnitude_range=(0, 0.9)), + dict( + type='Shear', + interpolation='bicubic', + magnitude_key='magnitude', + magnitude_range=(0, 0.3), + pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]), + direction='horizontal'), + dict( + type='Shear', + interpolation='bicubic', + magnitude_key='magnitude', + magnitude_range=(0, 0.3), + pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]), + direction='vertical'), + dict( + type='Translate', + interpolation='bicubic', + magnitude_key='magnitude', + magnitude_range=(0, 0.45), + pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]), + direction='horizontal'), + dict( + type='Translate', + interpolation='bicubic', + magnitude_key='magnitude', + magnitude_range=(0, 0.45), + pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]), + direction='vertical') +] + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='RandomResizedCrop', + size=224, + 
backend='pillow', + interpolation='bicubic'), + dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), + dict( + type='RandAugment', + policies=policies, + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5), + dict( + type='RandomErasing', + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=img_norm_cfg['mean'][::-1], + fill_std=img_norm_cfg['std'][::-1]), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='ToTensor', keys=['gt_label']), + dict(type='Collect', keys=['img', 'gt_label']) +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + size=(256, -1), + backend='pillow', + interpolation='bicubic'), + dict(type='CenterCrop', crop_size=224), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) +] +data = dict( + samples_per_gpu=128, + workers_per_gpu=8, + train=dict( + type=dataset_type, + data_prefix='data/imagenet/train', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_prefix='data/imagenet/val', + ann_file='data/imagenet/meta/val.txt', + pipeline=test_pipeline), + test=dict( + # replace `data/val` with `data/test` for standard test + type=dataset_type, + data_prefix='data/imagenet/val', + ann_file='data/imagenet/meta/val.txt', + pipeline=test_pipeline)) + +evaluation = dict(interval=10, metric='accuracy') diff --git a/configs/_base_/datasets/imagenet_bs128_swin_384.py b/configs/_base_/datasets/imagenet_bs128_swin_384.py new file mode 100644 index 00000000..5ac70cb0 --- /dev/null +++ b/configs/_base_/datasets/imagenet_bs128_swin_384.py @@ -0,0 +1,43 @@ +# dataset settings +dataset_type = 'ImageNet' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='RandomResizedCrop', + size=384, + backend='pillow', + interpolation='bicubic'), + dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='ToTensor', keys=['gt_label']), + dict(type='Collect', keys=['img', 'gt_label']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='Resize', size=384, backend='pillow', interpolation='bicubic'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) +] +data = dict( + samples_per_gpu=128, + workers_per_gpu=8, + train=dict( + type=dataset_type, + data_prefix='data/imagenet/train', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_prefix='data/imagenet/val', + ann_file='data/imagenet/meta/val.txt', + pipeline=test_pipeline), + test=dict( + # replace `data/val` with `data/test` for standard test + type=dataset_type, + data_prefix='data/imagenet/val', + ann_file='data/imagenet/meta/val.txt', + pipeline=test_pipeline)) +evaluation = dict(interval=10, metric='accuracy') diff --git a/configs/_base_/models/swin_transformer/base_224.py b/configs/_base_/models/swin_transformer/base_224.py new file mode 100644 index 00000000..387a38c7 --- /dev/null +++ b/configs/_base_/models/swin_transformer/base_224.py @@ -0,0 +1,22 @@ +# model settings +model = dict( + type='ImageClassifier', + backbone=dict( + type='SwinTransformer', arch='base', img_size=224, drop_path_rate=0.5), + neck=dict(type='GlobalAveragePooling', dim=1), + head=dict( + type='LinearClsHead', + 
num_classes=1000, + in_channels=1024, + init_cfg=None, # suppress the default init_cfg of LinearClsHead. + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + cal_acc=False), + init_cfg=[ + dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict(augments=[ + dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5), + dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) + ])) diff --git a/configs/_base_/models/swin_transformer/base_384.py b/configs/_base_/models/swin_transformer/base_384.py new file mode 100644 index 00000000..9fb42893 --- /dev/null +++ b/configs/_base_/models/swin_transformer/base_384.py @@ -0,0 +1,16 @@ +# model settings +# Only for evaluation +model = dict( + type='ImageClassifier', + backbone=dict( + type='SwinTransformer', + arch='base', + img_size=384, + stage_cfgs=dict(block_cfgs=dict(window_size=12))), + neck=dict(type='GlobalAveragePooling', dim=1), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=1024, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + topk=(1, 5))) diff --git a/configs/_base_/models/swin_transformer/large_224.py b/configs/_base_/models/swin_transformer/large_224.py new file mode 100644 index 00000000..ee322b8b --- /dev/null +++ b/configs/_base_/models/swin_transformer/large_224.py @@ -0,0 +1,12 @@ +# model settings +# Only for evaluation +model = dict( + type='ImageClassifier', + backbone=dict(type='SwinTransformer', arch='large', img_size=224), + neck=dict(type='GlobalAveragePooling', dim=1), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=1536, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + topk=(1, 5))) diff --git a/configs/_base_/models/swin_transformer/large_384.py b/configs/_base_/models/swin_transformer/large_384.py new file mode 100644 index 00000000..a38b182c --- /dev/null +++ b/configs/_base_/models/swin_transformer/large_384.py @@ -0,0 +1,16 @@ +# model settings +# Only for evaluation +model = dict( + type='ImageClassifier', + backbone=dict( + type='SwinTransformer', + arch='large', + img_size=384, + stage_cfgs=dict(block_cfgs=dict(window_size=12))), + neck=dict(type='GlobalAveragePooling', dim=1), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=1536, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + topk=(1, 5))) diff --git a/configs/_base_/models/swin_transformer/small_224.py b/configs/_base_/models/swin_transformer/small_224.py new file mode 100644 index 00000000..f3dab785 --- /dev/null +++ b/configs/_base_/models/swin_transformer/small_224.py @@ -0,0 +1,23 @@ +# model settings +model = dict( + type='ImageClassifier', + backbone=dict( + type='SwinTransformer', arch='small', img_size=224, + drop_path_rate=0.3), + neck=dict(type='GlobalAveragePooling', dim=1), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=768, + init_cfg=None, # suppress the default init_cfg of LinearClsHead. + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + cal_acc=False), + init_cfg=[ + dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.) 
+ ], + train_cfg=dict(augments=[ + dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5), + dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) + ])) diff --git a/configs/_base_/models/swin_transformer/tiny_224.py b/configs/_base_/models/swin_transformer/tiny_224.py new file mode 100644 index 00000000..584586f4 --- /dev/null +++ b/configs/_base_/models/swin_transformer/tiny_224.py @@ -0,0 +1,22 @@ +# model settings +model = dict( + type='ImageClassifier', + backbone=dict( + type='SwinTransformer', arch='tiny', img_size=224, drop_path_rate=0.2), + neck=dict(type='GlobalAveragePooling', dim=1), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=768, + init_cfg=None, # suppress the default init_cfg of LinearClsHead. + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + cal_acc=False), + init_cfg=[ + dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict(augments=[ + dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5), + dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) + ])) diff --git a/configs/_base_/schedules/imagenet_bs1024_adamw_swin.py b/configs/_base_/schedules/imagenet_bs1024_adamw_swin.py new file mode 100644 index 00000000..1a523e44 --- /dev/null +++ b/configs/_base_/schedules/imagenet_bs1024_adamw_swin.py @@ -0,0 +1,30 @@ +paramwise_cfg = dict( + norm_decay_mult=0.0, + bias_decay_mult=0.0, + custom_keys={ + '.absolute_pos_embed': dict(decay_mult=0.0), + '.relative_position_bias_table': dict(decay_mult=0.0) + }) + +# for batch size 128 per GPU on 8 GPUs +# lr = 5e-4 * 128 * 8 / 512 = 0.001 +optimizer = dict( + type='AdamW', + lr=5e-4 * 128 * 8 / 512, + weight_decay=0.05, + eps=1e-8, + betas=(0.9, 0.999), + paramwise_cfg=paramwise_cfg) +optimizer_config = dict(grad_clip=dict(max_norm=5.0)) + +# learning policy +lr_config = dict( + policy='CosineAnnealing', + by_epoch=False, + min_lr_ratio=1e-2, + warmup='linear', + warmup_ratio=1e-3, + warmup_iters=20 * 1252, + warmup_by_epoch=False) + +runner = dict(type='EpochBasedRunner', max_epochs=300) diff --git a/configs/swin_transformer/README.md b/configs/swin_transformer/README.md new file mode 100644 index 00000000..fa880227 --- /dev/null +++ b/configs/swin_transformer/README.md @@ -0,0 +1,41 @@ +# Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + +## Introduction + +[ALGORITHM] + +```latex +@article{liu2021Swin, + title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows}, + author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining}, + journal={arXiv preprint arXiv:2103.14030}, + year={2021} +} +``` + +## Pretrained models + +The pre-trained models are converted from [model zoo of Swin Transformer](https://github.com/microsoft/Swin-Transformer#main-results-on-imagenet-with-pretrained-models). 
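A converted checkpoint can be loaded directly into a classifier built from one of the evaluation-only configs added in this patch. The snippet below is a minimal sketch, not part of the patch; it assumes mmcv's `Config` and `load_checkpoint`, mmcls' `build_classifier`, and the Swin-B 384px config and checkpoint URL listed in this README.

```python
# Minimal sketch: build the eval-only Swin-B 384px classifier and load a
# converted checkpoint. Config path and checkpoint URL come from this patch.
import torch
from mmcv import Config
from mmcv.runner import load_checkpoint

from mmcls.models import build_classifier

cfg = Config.fromfile('configs/swin_transformer/swin_base_384_imagenet.py')
model = build_classifier(cfg.model)
# The converted checkpoints contain full classifier weights (backbone + head).
load_checkpoint(
    model,
    'https://download.openmmlab.com/mmclassification/v0/swin-transformer/'
    'convert/swin_base_patch4_window12_384-02c598a4.pth',
    map_location='cpu')
model.eval()

with torch.no_grad():
    # Backbone output is (B, C, L); the unit tests expect (1, 1024, 144) at 384px.
    feats = model.backbone(torch.randn(1, 3, 384, 384))
print(feats.shape)
```

The same pattern applies to any row in the tables below; only the config path and checkpoint URL change.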
+ +### ImageNet 1k + +| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Download | +|:---------:|:------------:|:-----------:|:---------:|:---------:|:---------:|:---------:|:--------:| +| Swin-T | ImageNet-1k | 224x224 | 28.29 | 4.36 | 81.18 | 95.52 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_tiny_patch4_window7_224-160bb0a5.pth)| +| Swin-S | ImageNet-1k | 224x224 | 49.61 | 8.52 | 83.21 | 96.25 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_small_patch4_window7_224-cc7a01c9.pth)| +| Swin-B | ImageNet-1k | 224x224 | 87.77 | 15.14 | 83.42 | 96.44 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window7_224-4670dd19.pth)| +| Swin-B | ImageNet-1k | 384x384 | 87.90 | 44.49 | 84.49 | 96.95 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window12_384-02c598a4.pth)| +| Swin-B | ImageNet-22k | 224x224 | 87.77 | 15.14 | 85.16 | 97.50 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window7_224_22kto1k-f967f799.pth)| +| Swin-B | ImageNet-22k | 384x384 | 87.90 | 44.49 | 86.44 | 98.05 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window12_384_22kto1k-d59b0d1d.pth)| +| Swin-L | ImageNet-22k | 224x224 | 196.53 | 34.04 | 86.24 | 97.88 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_large_patch4_window7_224_22kto1k-5f0996db.pth)| +| Swin-L | ImageNet-22k | 384x384 | 196.74 | 100.04 | 87.25 | 98.25 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_large_patch4_window12_384_22kto1k-0a40944b.pth)| + + +## Results and models + +### ImageNet +| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | +|:---------:|:------------:|:-----------:|:---------:|:---------:|:---------:|:---------:|:----------:|:--------:| +| Swin-T | ImageNet-1k | 224x224 | 28.29 | 4.36 | 81.18 | 95.61 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_tiny_224_imagenet.py) |[model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_imagenet-66df6be6.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_imagenet-66df6be6.log.json)| +| Swin-S | ImageNet-1k | 224x224 | 49.61 | 8.52 | 83.02 | 96.29 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_small_224_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_imagenet-7f9d988b.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_imagenet-7f9d988b.log.json)| +| Swin-B | ImageNet-1k | 224x224 | 87.77 | 15.14 | 83.36 | 96.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_base_224_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_imagenet-93230b0d.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_imagenet-93230b0d.log.json)| diff --git a/configs/swin_transformer/metafile.yml b/configs/swin_transformer/metafile.yml new file mode 100644 index 00000000..fdd34591 --- /dev/null +++ b/configs/swin_transformer/metafile.yml @@ 
-0,0 +1,67 @@ +Collections: + - Name: Swin-Transformer + Metadata: + Training Data: ImageNet + Training Techniques: + - AdamW + - Weight Decay + Training Resources: 16x V100 GPUs + Epochs: 300 + Batch Size: 1024 + Architecture: + - Shift Window Multihead Self Attention + Paper: https://arxiv.org/pdf/2103.14030.pdf + README: configs/swin_transformer/README.md + +Models: +- Config: configs/swin_transformer/swin_tiny_224_imagenet.py + In Collection: Swin-Transformer + Metadata: + FLOPs: 4360000000 + Parameters: 28290000 + Training Data: ImageNet + Training Resources: 16x 1080 GPUs + Epochs: 300 + Batch Size: 1024 + Name: swin_tiny_224_imagenet + Results: + - Dataset: ImageNet + Metrics: + Top 1 Accuracy: 81.18 + Top 5 Accuracy: 95.61 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_imagenet-66df6be6.pth +- Config: configs/swin_transformer/swin_small_224_imagenet.py + In Collection: Swin-Transformer + Metadata: + FLOPs: 8520000000 + Parameters: 48610000 + Training Data: ImageNet + Training Resources: 16x 1080 GPUs + Epochs: 300 + Batch Size: 1024 + Name: swin_small_224_imagenet + Results: + - Dataset: ImageNet + Metrics: + Top 1 Accuracy: 83.02 + Top 5 Accuracy: 96.29 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_imagenet-7f9d988b.pth +- Config: configs/swin_transformer/swin_base_224_imagenet.py + In Collection: Swin-Transformer + Metadata: + FLOPs: 15140000000 + Parameters: 87770000 + Training Data: ImageNet + Training Resources: 16x 1080 GPUs + Epochs: 300 + Batch Size: 1024 + Name: swin_base_224_imagenet + Results: + - Dataset: ImageNet + Metrics: + Top 1 Accuracy: 83.36 + Top 5 Accuracy: 96.44 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_imagenet-93230b0d.pth diff --git a/configs/swin_transformer/swin_base_224_imagenet.py b/configs/swin_transformer/swin_base_224_imagenet.py new file mode 100644 index 00000000..e1ce91d9 --- /dev/null +++ b/configs/swin_transformer/swin_base_224_imagenet.py @@ -0,0 +1,6 @@ +_base_ = [ + '../_base_/models/swin_transformer/base_224.py', + '../_base_/datasets/imagenet_bs128_swin_224.py', + '../_base_/schedules/imagenet_bs1024_adamw_swin.py', + '../_base_/default_runtime.py' +] diff --git a/configs/swin_transformer/swin_base_384_imagenet.py b/configs/swin_transformer/swin_base_384_imagenet.py new file mode 100644 index 00000000..194de565 --- /dev/null +++ b/configs/swin_transformer/swin_base_384_imagenet.py @@ -0,0 +1,7 @@ +# Only for evaluation +_base_ = [ + '../_base_/models/swin_transformer/base_384.py', + '../_base_/datasets/imagenet_bs128_swin_384.py', + '../_base_/schedules/imagenet_bs1024_adamw_swin.py', + '../_base_/default_runtime.py' +] diff --git a/configs/swin_transformer/swin_large_224_imagenet.py b/configs/swin_transformer/swin_large_224_imagenet.py new file mode 100644 index 00000000..39c896bf --- /dev/null +++ b/configs/swin_transformer/swin_large_224_imagenet.py @@ -0,0 +1,7 @@ +# Only for evaluation +_base_ = [ + '../_base_/models/swin_transformer/large_224.py', + '../_base_/datasets/imagenet_bs128_swin_224.py', + '../_base_/schedules/imagenet_bs1024_adamw_swin.py', + '../_base_/default_runtime.py' +] diff --git a/configs/swin_transformer/swin_large_384_imagenet.py b/configs/swin_transformer/swin_large_384_imagenet.py new file mode 100644 index 00000000..6f4b7b7a --- /dev/null +++ 
b/configs/swin_transformer/swin_large_384_imagenet.py @@ -0,0 +1,7 @@ +# Only for evaluation +_base_ = [ + '../_base_/models/swin_transformer/large_384.py', + '../_base_/datasets/imagenet_bs128_swin_384.py', + '../_base_/schedules/imagenet_bs1024_adamw_swin.py', + '../_base_/default_runtime.py' +] diff --git a/configs/swin_transformer/swin_small_224_imagenet.py b/configs/swin_transformer/swin_small_224_imagenet.py new file mode 100644 index 00000000..1bf08afa --- /dev/null +++ b/configs/swin_transformer/swin_small_224_imagenet.py @@ -0,0 +1,6 @@ +_base_ = [ + '../_base_/models/swin_transformer/small_224.py', + '../_base_/datasets/imagenet_bs128_swin_224.py', + '../_base_/schedules/imagenet_bs1024_adamw_swin.py', + '../_base_/default_runtime.py' +] diff --git a/configs/swin_transformer/swin_tiny_224_imagenet.py b/configs/swin_transformer/swin_tiny_224_imagenet.py new file mode 100644 index 00000000..7f537864 --- /dev/null +++ b/configs/swin_transformer/swin_tiny_224_imagenet.py @@ -0,0 +1,6 @@ +_base_ = [ + '../_base_/models/swin_transformer/tiny_224.py', + '../_base_/datasets/imagenet_bs128_swin_224.py', + '../_base_/schedules/imagenet_bs1024_adamw_swin.py', + '../_base_/default_runtime.py' +] diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 1fcabf98..1c1baa41 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -40,6 +40,9 @@ The ResNet family models below are trained by standard data augmentations, i.e., | ViT-B/32* | 88.3 | 8.56 | 81.73 | 96.13 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/vision_transformer/vit_base_patch32_384_finetune_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/vit/vit_base_patch32_384.pth) | [log]() | | ViT-L/16* | 304.72 | 116.68 | 85.08 | 97.38 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/vision_transformer/vit_large_patch16_384_finetune_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/vit/vit_large_patch16_384.pth) | [log]() | | ViT-L/32* | 306.63 | 29.66 | 81.52 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/vision_transformer/vit_large_patch32_384_finetune_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/vit/vit_large_patch32_384.pth) | [log]() | +| Swin-Transformer tiny | 28.29 | 4.36 | 81.18 | 95.61 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_tiny_224_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_imagenet-66df6be6.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_imagenet-66df6be6.log.json)| +| Swin-Transformer small| 49.61 | 8.52 | 83.02 | 96.29 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_small_224_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_imagenet-7f9d988b.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_imagenet-7f9d988b.log.json)| +| Swin-Transformer base | 87.77 | 15.14 | 83.36 | 96.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_base_224_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_imagenet-93230b0d.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_imagenet-93230b0d.log.json)| 
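The Params/FLOPs columns above can be sanity-checked with mmcv's complexity counter. The sketch below is an illustration, not part of this patch: it runs on the Swin-T backbone only (so the 1000-way linear head is excluded), and a hook-based counter does not see the attention matrix multiplications, so expect only approximate agreement with the table.

```python
# Approximate check of the Params/FLOPs columns, backbone only.
# get_model_complexity_info is mmcv's hook-based counter; it misses the
# attention matmuls, so treat the FLOPs figure as a lower bound.
from mmcv.cnn import get_model_complexity_info

from mmcls.models import SwinTransformer

backbone = SwinTransformer(arch='tiny', img_size=224)
backbone.eval()
flops, params = get_model_complexity_info(
    backbone, (3, 224, 224), as_strings=True, print_per_layer_stat=False)
print(f'FLOPs: {flops}, Params: {params}')
```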
Models with * are converted from other repos, others are trained by ourselves. diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py index 62ed0f45..230d0ae9 100644 --- a/mmcls/models/backbones/__init__.py +++ b/mmcls/models/backbones/__init__.py @@ -11,11 +11,13 @@ from .seresnet import SEResNet from .seresnext import SEResNeXt from .shufflenet_v1 import ShuffleNetV1 from .shufflenet_v2 import ShuffleNetV2 +from .swin_transformer import SwinTransformer from .vgg import VGG from .vision_transformer import VisionTransformer __all__ = [ 'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d', 'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1', - 'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer' + 'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer', + 'SwinTransformer' ] diff --git a/mmcls/models/backbones/swin_transformer.py b/mmcls/models/backbones/swin_transformer.py new file mode 100644 index 00000000..6ea926b1 --- /dev/null +++ b/mmcls/models/backbones/swin_transformer.py @@ -0,0 +1,349 @@ +from copy import deepcopy +from typing import Sequence + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN +from mmcv.cnn.utils.weight_init import trunc_normal_ +from mmcv.runner.base_module import BaseModule, ModuleList + +from ..builder import BACKBONES +from ..utils import PatchEmbed, PatchMerging, ShiftWindowMSA +from .base_backbone import BaseBackbone + + +class SwinBlock(BaseModule): + """Swin Transformer block. + + Args: + embed_dims (int): Number of input channels. + input_resolution (Tuple[int, int]): The resolution of the input feature + map. + num_heads (int): Number of attention heads. + window_size (int, optional): The height and width of the window. + Defaults to 7. + shift (bool, optional): Shift the attention window or not. + Defaults to False. + ffn_ratio (float, optional): The expansion ratio of feedforward network + hidden layer channels. Defaults to 4. + drop_path (float, optional): The drop path rate after attention and + ffn. Defaults to 0. + attn_cfgs (dict, optional): The extra config of Shift Window-MSA. + Defaults to empty dict. + ffn_cfgs (dict, optional): The extra config of FFN. + Defaults to empty dict. + norm_cfg (dict, optional): The config of norm layers. + Defaults to dict(type='LN'). + auto_pad (bool, optional): Auto pad the feature map to be divisible by + window_size, Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Default: None. 
+ """ + + def __init__(self, + embed_dims, + input_resolution, + num_heads, + window_size=7, + shift=False, + ffn_ratio=4., + drop_path=0., + attn_cfgs=dict(), + ffn_cfgs=dict(), + norm_cfg=dict(type='LN'), + auto_pad=False, + init_cfg=None): + + super(SwinBlock, self).__init__(init_cfg) + + _attn_cfgs = { + 'embed_dims': embed_dims, + 'input_resolution': input_resolution, + 'num_heads': num_heads, + 'shift_size': window_size // 2 if shift else 0, + 'window_size': window_size, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'auto_pad': auto_pad, + **attn_cfgs + } + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = ShiftWindowMSA(**_attn_cfgs) + + _ffn_cfgs = { + 'embed_dims': embed_dims, + 'feedforward_channels': int(embed_dims * ffn_ratio), + 'num_fcs': 2, + 'ffn_drop': 0, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'act_cfg': dict(type='GELU'), + **ffn_cfgs + } + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + self.ffn = FFN(**_ffn_cfgs) + + def forward(self, x): + identity = x + x = self.norm1(x) + x = self.attn(x) + x = x + identity + + identity = x + x = self.norm2(x) + x = self.ffn(x, identity=identity) + return x + + +class SwinBlockSequence(BaseModule): + """Module with successive Swin Transformer blocks and downsample layer. + + Args: + embed_dims (int): Number of input channels. + input_resolution (Tuple[int, int]): The resolution of the input feature + map. + depth (int): Number of successive swin transformer blocks. + num_heads (int): Number of attention heads. + downsample (bool, optional): Downsample the output of blocks by patch + merging. Defaults to False. + downsample_cfg (dict, optional): The extra config of the patch merging + layer. Defaults to empty dict. + drop_paths (Sequence[float] | float, optional): The drop path rate in + each block. Defaults to 0. + block_cfgs (Sequence[dict] | dict, optional): The extra config of each + block. Defaults to empty dicts. + auto_pad (bool, optional): Auto pad the feature map to be divisible by + window_size, Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Default: None. 
+ """ + + def __init__(self, + embed_dims, + input_resolution, + depth, + num_heads, + downsample=False, + downsample_cfg=dict(), + drop_paths=0., + block_cfgs=dict(), + auto_pad=False, + init_cfg=None): + super().__init__(init_cfg) + + if not isinstance(drop_paths, Sequence): + drop_paths = [drop_paths] * depth + + if not isinstance(block_cfgs, Sequence): + block_cfg = [deepcopy(block_cfgs) for _ in range(depth)] + + self.blocks = ModuleList() + for i in range(depth): + _block_cfg = { + 'embed_dims': embed_dims, + 'input_resolution': input_resolution, + 'num_heads': num_heads, + 'shift': False if i % 2 == 0 else True, + 'drop_path': drop_paths[i], + 'auto_pad': auto_pad, + **block_cfg[i] + } + block = SwinBlock(**_block_cfg) + self.blocks.append(block) + + if downsample: + _downsample_cfg = { + 'input_resolution': input_resolution, + 'in_channels': embed_dims, + 'expansion_ratio': 2, + 'norm_cfg': dict(type='LN'), + **downsample_cfg + } + self.downsample = PatchMerging(**_downsample_cfg) + else: + self.downsample = None + + def forward(self, x): + for block in self.blocks: + x = block(x) + + if self.downsample: + x = self.downsample(x) + return x + + +@BACKBONES.register_module() +class SwinTransformer(BaseBackbone): + """ Swin Transformer + A PyTorch implement of : `Swin Transformer: + Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/abs/2103.14030 + + Inspiration from + https://github.com/microsoft/Swin-Transformer + + Args: + arch (str | dict): Swin Transformer architecture + Defaults to 'T'. + img_size (int | tuple): The size of input image. + Defaults to 224. + in_channels (int): The num of input channels. + Defaults to 3. + drop_rate (float): Dropout rate after embedding. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. + Defaults to 0.1. + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults to False. + auto_pad (bool): If True, auto pad feature map to fit window_size. + Defaults to False. + norm_cfg (dict, optional): Config dict for normalization layer at end + of backone. Defaults to dict(type='LN') + stage_cfgs (Sequence | dict, optional): Extra config dict for each + stage. Defaults to empty dict. + patch_cfg (dict, optional): Extra config dict for patch embedding. + Defaults to empty dict. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. 
+ + Examples: + >>> from mmcls.models import SwinTransformer + >>> import torch + >>> extra_config = dict( + >>> arch='tiny', + >>> stage_cfgs=dict(downsample_cfg={'kernel_size': 3, + >>> 'expansion_ratio': 3}), + >>> auto_pad=True) + >>> self = SwinTransformer(**extra_config) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> output = self.forward(inputs) + >>> print(output.shape) + (1, 2592, 4) + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], + {'embed_dims': 96, + 'depths': [2, 2, 6, 2], + 'num_heads': [3, 6, 12, 24]}), + **dict.fromkeys(['s', 'small'], + {'embed_dims': 96, + 'depths': [2, 2, 18, 2], + 'num_heads': [3, 6, 12, 24]}), + **dict.fromkeys(['b', 'base'], + {'embed_dims': 128, + 'depths': [2, 2, 18, 2], + 'num_heads': [4, 8, 16, 32]}), + **dict.fromkeys(['l', 'large'], + {'embed_dims': 192, + 'depths': [2, 2, 18, 2], + 'num_heads': [6, 12, 24, 48]}), + } # yapf: disable + + def __init__(self, + arch='T', + img_size=224, + in_channels=3, + drop_rate=0., + drop_path_rate=0.1, + use_abs_pos_embed=False, + auto_pad=False, + norm_cfg=dict(type='LN'), + stage_cfgs=dict(), + patch_cfg=dict(), + init_cfg=None): + super(SwinTransformer, self).__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'embed_dims', 'depths', 'num_heads'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + self.num_layers = len(self.depths) + self.use_abs_pos_embed = use_abs_pos_embed + self.auto_pad = auto_pad + + _patch_cfg = dict( + img_size=img_size, + in_channels=in_channels, + embed_dims=self.embed_dims, + conv_cfg=dict( + type='Conv2d', kernel_size=4, stride=4, padding=0, dilation=1), + norm_cfg=dict(type='LN'), + **patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + if self.use_abs_pos_embed: + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.embed_dims)) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + # stochastic depth + total_depth = sum(self.depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + + self.stages = ModuleList() + embed_dims = self.embed_dims + input_resolution = patches_resolution + for i, (depth, + num_heads) in enumerate(zip(self.depths, self.num_heads)): + if isinstance(stage_cfgs, Sequence): + stage_cfg = stage_cfgs[i] + else: + stage_cfg = deepcopy(stage_cfgs) + downsample = True if i < self.num_layers - 1 else False + _stage_cfg = { + 'embed_dims': embed_dims, + 'depth': depth, + 'num_heads': num_heads, + 'downsample': downsample, + 'input_resolution': input_resolution, + 'drop_paths': dpr[:depth], + 'auto_pad': auto_pad, + **stage_cfg + } + + stage = SwinBlockSequence(**_stage_cfg) + self.stages.append(stage) + + dpr = dpr[depth:] + if downsample: + embed_dims = stage.downsample.out_channels + input_resolution = stage.downsample.output_resolution + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + else: + self.norm = None + + def init_weights(self): + 
super(SwinTransformer, self).init_weights() + + if self.use_abs_pos_embed: + trunc_normal_(self.absolute_pos_embed, std=0.02) + + def forward(self, x): + x = self.patch_embed(x) + if self.use_abs_pos_embed: + x = x + self.absolute_pos_embed + x = self.drop_after_pos(x) + + for stage in self.stages: + x = stage(x) + + x = self.norm(x) if self.norm else x + + return x.transpose(1, 2) diff --git a/mmcls/models/utils/__init__.py b/mmcls/models/utils/__init__.py index cfa05c28..aaa1d4ef 100644 --- a/mmcls/models/utils/__init__.py +++ b/mmcls/models/utils/__init__.py @@ -1,6 +1,7 @@ +from .attention import ShiftWindowMSA from .augment.augments import Augments from .channel_shuffle import channel_shuffle -from .embed import HybridEmbed, PatchEmbed +from .embed import HybridEmbed, PatchEmbed, PatchMerging from .helpers import to_2tuple, to_3tuple, to_4tuple, to_ntuple from .inverted_residual import InvertedResidual from .make_divisible import make_divisible @@ -8,6 +9,6 @@ from .se_layer import SELayer __all__ = [ 'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer', - 'to_ntuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'Augments', - 'HybridEmbed', 'PatchEmbed' + 'to_ntuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'PatchEmbed', + 'PatchMerging', 'HybridEmbed', 'Augments', 'ShiftWindowMSA' ] diff --git a/mmcls/models/utils/attention.py b/mmcls/models/utils/attention.py new file mode 100644 index 00000000..a4e918b8 --- /dev/null +++ b/mmcls/models/utils/attention.py @@ -0,0 +1,289 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks.registry import ATTENTION +from mmcv.cnn.bricks.transformer import build_dropout +from mmcv.cnn.utils.weight_init import trunc_normal_ +from mmcv.runner.base_module import BaseModule + +from .helpers import to_2tuple + + +class WindowMSA(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + Args: + embed_dims (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Defaults to True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Defaults to None. + attn_drop (float, optional): Dropout ratio of attention weight. + Defaults to 0. + proj_drop (float, optional): Dropout ratio of output. Defaults to 0. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. 
+ """ + + def __init__(self, + embed_dims, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0., + init_cfg=None): + + super().__init__(init_cfg) + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = qk_scale or head_embed_dims**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # About 2x faster than original impl + Wh, Ww = self.window_size + rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww) + rel_position_index = rel_index_coords + rel_index_coords.T + rel_position_index = rel_position_index.flip(1).contiguous() + self.register_buffer('relative_position_index', rel_position_index) + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def init_weights(self): + super(WindowMSA, self).init_weights() + + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x, mask=None): + """ + Args: + + x (tensor): input features with shape of (num_windows*B, N, C) + mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww, + Wh*Ww), value should be between (-inf, 0]. + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @staticmethod + def double_step_seq(step1, len1, step2, len2): + seq1 = torch.arange(0, step1 * len1, step1) + seq2 = torch.arange(0, step2 * len2, step2) + return (seq1[:, None] + seq2[None, :]).reshape(1, -1) + + +@ATTENTION.register_module() +class ShiftWindowMSA(BaseModule): + """Shift Window Multihead Self-Attention Module. + + Args: + embed_dims (int): Number of input channels. + input_resolution (Tuple[int, int]): The resolution of the input feature + map. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. + shift_size (int, optional): The shift step of each window towards + right-bottom. If zero, act as regular window-msa. Defaults to 0. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Defaults to None. + attn_drop (float, optional): Dropout ratio of attention weight. + Defaults to 0.0. 
+ proj_drop (float, optional): Dropout ratio of output. Defaults to 0. + dropout_layer (dict, optional): The dropout_layer used before output. + Defaults to dict(type='DropPath', drop_prob=0.). + auto_pad (bool, optional): Auto pad the feature map to be divisible by + window_size, Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + input_resolution, + num_heads, + window_size, + shift_size=0, + qkv_bias=True, + qk_scale=None, + attn_drop=0, + proj_drop=0, + dropout_layer=dict(type='DropPath', drop_prob=0.), + auto_pad=False, + init_cfg=None): + super().__init__(init_cfg) + + self.embed_dims = embed_dims + self.input_resolution = input_resolution + self.shift_size = shift_size + self.window_size = window_size + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, don't partition + self.shift_size = 0 + self.window_size = min(self.input_resolution) + + self.w_msa = WindowMSA(embed_dims, to_2tuple(self.window_size), + num_heads, qkv_bias, qk_scale, attn_drop, + proj_drop) + + self.drop = build_dropout(dropout_layer) + + H, W = self.input_resolution + # Handle auto padding + self.auto_pad = auto_pad + if self.auto_pad: + self.pad_r = (self.window_size - + W % self.window_size) % self.window_size + self.pad_b = (self.window_size - + H % self.window_size) % self.window_size + self.H_pad = H + self.pad_b + self.W_pad = W + self.pad_r + else: + H_pad, W_pad = self.input_resolution + assert H_pad % self.window_size + W_pad % self.window_size == 0,\ + f'input_resolution({self.input_resolution}) is not divisible '\ + f'by window_size({self.window_size}). Please check feature '\ + f'map shape or set `auto_pad=True`.' + self.H_pad, self.W_pad = H_pad, W_pad + self.pad_r, self.pad_b = 0, 0 + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, self.H_pad, self.W_pad, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + # nW, window_size, window_size, 1 + mask_windows = self.window_partition(img_mask) + mask_windows = mask_windows.view( + -1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer('attn_mask', attn_mask) + + def forward(self, query): + H, W = self.input_resolution + B, L, C = query.shape + assert L == H * W, 'input feature has wrong size' + query = query.view(B, H, W, C) + + if self.pad_r or self.pad_b: + query = F.pad(query, (0, 0, 0, self.pad_r, 0, self.pad_b)) + + # cyclic shift + if self.shift_size > 0: + shifted_query = torch.roll( + query, + shifts=(-self.shift_size, -self.shift_size), + dims=(1, 2)) + else: + shifted_query = query + + # nW*B, window_size, window_size, C + query_windows = self.window_partition(shifted_query) + # nW*B, window_size*window_size, C + query_windows = query_windows.view(-1, self.window_size**2, C) + + # W-MSA/SW-MSA (nW*B, window_size*window_size, C) + attn_windows = self.w_msa(query_windows, mask=self.attn_mask) + + # merge windows + attn_windows = 
attn_windows.view(-1, self.window_size, + self.window_size, C) + + # B H' W' C + shifted_x = self.window_reverse(attn_windows, self.H_pad, self.W_pad) + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2)) + else: + x = shifted_x + + if self.pad_r or self.pad_b: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + x = self.drop(x) + return x + + def window_reverse(self, windows, H, W): + window_size = self.window_size + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + def window_partition(self, x): + B, H, W, C = x.shape + window_size = self.window_size + x = x.view(B, H // window_size, window_size, W // window_size, + window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = windows.view(-1, window_size, window_size, C) + return windows diff --git a/mmcls/models/utils/embed.py b/mmcls/models/utils/embed.py index 6c242435..c14453df 100644 --- a/mmcls/models/utils/embed.py +++ b/mmcls/models/utils/embed.py @@ -159,3 +159,92 @@ class HybridEmbed(BaseModule): x = x[-1] x = self.projection(x).flatten(2).transpose(1, 2) return x + + +class PatchMerging(BaseModule): + """Merge patch feature map. + + This layer use nn.Unfold to group feature map by kernel_size, and use norm + and linear layer to embed grouped feature map. + + Args: + input_resolution (tuple): The size of input patch resolution. + in_channels (int): The num of input channels. + expansion_ratio (Number): Expansion ratio of output channels. The num + of output channels is equal to int(expansion_ratio * in_channels). + kernel_size (int | tuple, optional): the kernel size in the unfold + layer. Defaults to 2. + stride (int | tuple, optional): the stride of the sliding blocks in the + unfold layer. Defaults to be equal with kernel_size. + padding (int | tuple, optional): zero padding width in the unfold + layer. Defaults to 0. + dilation (int | tuple, optional): dilation parameter in the unfold + layer. Defaults to 1. + bias (bool, optional): Whether to add bias in linear layer or not. + Defaults to False. + norm_cfg (dict, optional): Config dict for normalization layer. + Defaults to dict(type='LN'). + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. 
+ """ + + def __init__(self, + input_resolution, + in_channels, + expansion_ratio, + kernel_size=2, + stride=None, + padding=0, + dilation=1, + bias=False, + norm_cfg=dict(type='LN'), + init_cfg=None): + super().__init__(init_cfg) + H, W = input_resolution + self.input_resolution = input_resolution + self.in_channels = in_channels + self.out_channels = int(expansion_ratio * in_channels) + + if stride is None: + stride = kernel_size + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + padding = to_2tuple(padding) + dilation = to_2tuple(dilation) + self.sampler = nn.Unfold(kernel_size, dilation, padding, stride) + + sample_dim = kernel_size[0] * kernel_size[1] * in_channels + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, sample_dim)[1] + else: + self.norm = None + + self.reduction = nn.Linear(sample_dim, self.out_channels, bias=bias) + + # See https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + H_out = (H + 2 * padding[0] - dilation[0] * + (kernel_size[0] - 1) - 1) // stride[0] + 1 + W_out = (W + 2 * padding[1] - dilation[1] * + (kernel_size[1] - 1) - 1) // stride[1] + 1 + self.output_resolution = (H_out, W_out) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, 'input feature has wrong size' + + x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W + + # Use nn.Unfold to merge patch. About 25% faster than original method, + # but need to modify pretrained model for compatibility + x = self.sampler(x) # B, 4*C, H/2*W/2 + x = x.transpose(1, 2) # B, H/2*W/2, 4*C + + x = self.norm(x) if self.norm else x + x = self.reduction(x) + + return x diff --git a/tests/test_backbones/test_attention.py b/tests/test_backbones/test_attention.py new file mode 100644 index 00000000..cb8cbddb --- /dev/null +++ b/tests/test_backbones/test_attention.py @@ -0,0 +1,177 @@ +import numpy as np +import torch + +from mmcls.models.utils.attention import ShiftWindowMSA, WindowMSA + + +def get_relative_position_index(window_size): + """Method from original code of Swin-Transformer.""" + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + # 2, Wh*Ww, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + # Wh*Ww, Wh*Ww, 2 + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + return relative_position_index + + +def test_window_msa(): + batch_size = 1 + num_windows = (4, 4) + embed_dims = 96 + window_size = (7, 7) + num_heads = 4 + attn = WindowMSA( + embed_dims=embed_dims, window_size=window_size, num_heads=num_heads) + inputs = torch.rand((batch_size * num_windows[0] * num_windows[1], + window_size[0] * window_size[1], embed_dims)) + + # test forward + output = attn(inputs) + assert output.shape == inputs.shape + assert attn.relative_position_bias_table.shape == ( + (2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + + # test relative_position_bias_table init + attn.init_weights() + assert abs(attn.relative_position_bias_table).sum() > 0 + + # test non-square window_size + window_size = (6, 7) + attn = WindowMSA( + embed_dims=embed_dims, 
window_size=window_size, num_heads=num_heads) + inputs = torch.rand((batch_size * num_windows[0] * num_windows[1], + window_size[0] * window_size[1], embed_dims)) + output = attn(inputs) + assert output.shape == inputs.shape + + # test relative_position_index + expected_rel_pos_index = get_relative_position_index(window_size) + assert (attn.relative_position_index == expected_rel_pos_index).all() + + # test qkv_bias=True + attn = WindowMSA( + embed_dims=embed_dims, + window_size=window_size, + num_heads=num_heads, + qkv_bias=True) + assert attn.qkv.bias.shape == (embed_dims * 3, ) + + # test qkv_bias=False + attn = WindowMSA( + embed_dims=embed_dims, + window_size=window_size, + num_heads=num_heads, + qkv_bias=False) + assert attn.qkv.bias is None + + # test default qk_scale + attn = WindowMSA( + embed_dims=embed_dims, + window_size=window_size, + num_heads=num_heads, + qk_scale=None) + head_dims = embed_dims // num_heads + assert np.isclose(attn.scale, head_dims**-0.5) + + # test specified qk_scale + attn = WindowMSA( + embed_dims=embed_dims, + window_size=window_size, + num_heads=num_heads, + qk_scale=0.3) + assert attn.scale == 0.3 + + # test attn_drop + attn = WindowMSA( + embed_dims=embed_dims, + window_size=window_size, + num_heads=num_heads, + attn_drop=1.0) + inputs = torch.rand((batch_size * num_windows[0] * num_windows[1], + window_size[0] * window_size[1], embed_dims)) + # drop all attn output, output shuold be equal to proj.bias + assert torch.allclose(attn(inputs), attn.proj.bias) + + # test prob_drop + attn = WindowMSA( + embed_dims=embed_dims, + window_size=window_size, + num_heads=num_heads, + proj_drop=1.0) + assert (attn(inputs) == 0).all() + + +def test_shift_window_msa(): + batch_size = 1 + embed_dims = 96 + input_resolution = (14, 14) + num_heads = 4 + window_size = 7 + + # test forward + attn = ShiftWindowMSA( + embed_dims=embed_dims, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size) + inputs = torch.rand( + (batch_size, input_resolution[0] * input_resolution[1], embed_dims)) + output = attn(inputs) + assert output.shape == (inputs.shape) + assert attn.w_msa.relative_position_bias_table.shape == ((2 * window_size - + 1)**2, num_heads) + + # test forward with shift_size + attn = ShiftWindowMSA( + embed_dims=embed_dims, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=1) + output = attn(inputs) + assert output.shape == (inputs.shape) + + # test relative_position_bias_table init + attn.init_weights() + assert abs(attn.w_msa.relative_position_bias_table).sum() > 0 + + # test dropout_layer + attn = ShiftWindowMSA( + embed_dims=embed_dims, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + dropout_layer=dict(type='DropPath', drop_prob=0.5)) + torch.manual_seed(0) + output = attn(inputs) + assert (output == 0).all() + + # test auto_pad + input_resolution = (19, 18) + attn = ShiftWindowMSA( + embed_dims=embed_dims, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + auto_pad=True) + assert attn.pad_r == 3 + assert attn.pad_b == 2 + + # test small input_resolution + input_resolution = (5, 6) + attn = ShiftWindowMSA( + embed_dims=embed_dims, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=3, + auto_pad=True) + assert attn.window_size == 5 + assert attn.shift_size == 0 diff --git a/tests/test_backbones/test_embed.py b/tests/test_backbones/test_embed.py new file mode 
100644 index 00000000..77862a6e --- /dev/null +++ b/tests/test_backbones/test_embed.py @@ -0,0 +1,56 @@ +import pytest +import torch + +from mmcls.models.utils import PatchMerging + + +def cal_unfold_dim(dim, kernel_size, stride, padding=0, dilation=1): + return (dim + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1 + + +def test_patch_merging(): + settings = dict( + input_resolution=(56, 56), in_channels=16, expansion_ratio=2) + downsample = PatchMerging(**settings) + + # test forward with wrong dims + with pytest.raises(AssertionError): + inputs = torch.rand((1, 16, 56 * 56)) + downsample(inputs) + + # test patch merging forward + inputs = torch.rand((1, 56 * 56, 16)) + out = downsample(inputs) + assert downsample.output_resolution == (28, 28) + assert out.shape == (1, 28 * 28, 32) + + # test different kernel_size in each direction + downsample = PatchMerging(kernel_size=(2, 3), **settings) + out = downsample(inputs) + expected_dim = cal_unfold_dim(56, 2, 2) * cal_unfold_dim(56, 3, 3) + assert downsample.sampler.kernel_size == (2, 3) + assert downsample.output_resolution == (cal_unfold_dim(56, 2, 2), + cal_unfold_dim(56, 3, 3)) + assert out.shape == (1, expected_dim, 32) + + # test default stride + downsample = PatchMerging(kernel_size=6, **settings) + assert downsample.sampler.stride == (6, 6) + + # test stride=3 + downsample = PatchMerging(kernel_size=6, stride=3, **settings) + out = downsample(inputs) + assert downsample.sampler.stride == (3, 3) + assert out.shape == (1, cal_unfold_dim(56, 6, stride=3)**2, 32) + + # test padding + downsample = PatchMerging(kernel_size=6, padding=2, **settings) + out = downsample(inputs) + assert downsample.sampler.padding == (2, 2) + assert out.shape == (1, cal_unfold_dim(56, 6, 6, padding=2)**2, 32) + + # test dilation + downsample = PatchMerging(kernel_size=6, dilation=2, **settings) + out = downsample(inputs) + assert downsample.sampler.dilation == (2, 2) + assert out.shape == (1, cal_unfold_dim(56, 6, 6, dilation=2)**2, 32) diff --git a/tests/test_backbones/test_swin_transformer.py b/tests/test_backbones/test_swin_transformer.py new file mode 100644 index 00000000..969ad456 --- /dev/null +++ b/tests/test_backbones/test_swin_transformer.py @@ -0,0 +1,144 @@ +from math import ceil + +import numpy as np +import pytest +import torch + +from mmcls.models.backbones import SwinTransformer + + +def test_swin_transformer(): + """Test Swin Transformer backbone.""" + with pytest.raises(AssertionError): + # Swin Transformer arch string should be in + SwinTransformer(arch='unknown') + + with pytest.raises(AssertionError): + # Swin Transformer arch dict should include 'embed_dims', + # 'depths' and 'num_head' keys. 
+ SwinTransformer(arch=dict(embed_dims=96, depths=[2, 2, 18, 2])) + + # Test tiny arch forward + model = SwinTransformer(arch='Tiny') + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + output = model(imgs) + assert output.shape == (1, 768, 49) + + # Test small arch forward + model = SwinTransformer(arch='small') + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + output = model(imgs) + assert output.shape == (1, 768, 49) + + # Test base arch forward + model = SwinTransformer(arch='B') + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + output = model(imgs) + assert output.shape == (1, 1024, 49) + + # Test large arch forward + model = SwinTransformer(arch='l') + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + output = model(imgs) + assert output.shape == (1, 1536, 49) + + # Test base arch with window_size=12, image_size=384 + model = SwinTransformer( + arch='base', + img_size=384, + stage_cfgs=dict(block_cfgs=dict(window_size=12))) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 384, 384) + output = model(imgs) + assert output.shape == (1, 1024, 144) + + # Test small with use_abs_pos_embed = True + model = SwinTransformer(arch='small', use_abs_pos_embed=True) + model.init_weights() + model.train() + + assert model.absolute_pos_embed.shape == (1, 3136, 96) + + # Test small with use_abs_pos_embed = False + with pytest.raises(AttributeError): + model = SwinTransformer(arch='small', use_abs_pos_embed=False) + model.absolute_pos_embed + + # Test small with auto_pad = True + model = SwinTransformer( + arch='small', + auto_pad=True, + stage_cfgs=dict( + block_cfgs={'window_size': 7}, + downsample_cfg={ + 'kernel_size': (3, 2), + })) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + + # stage 1 + input_h = int(224 / 4 / 3) + expect_h = ceil(input_h / 7) * 7 + input_w = int(224 / 4 / 2) + expect_w = ceil(input_w / 7) * 7 + assert model.stages[1].blocks[0].attn.pad_b == expect_h - input_h + assert model.stages[1].blocks[0].attn.pad_r == expect_w - input_w + + # stage 2 + input_h = int(224 / 4 / 3 / 3) + # input_h is smaller than window_size, shrink the window_size to input_h. + expect_h = input_h + input_w = int(224 / 4 / 2 / 2) + expect_w = ceil(input_w / input_h) * input_h + assert model.stages[2].blocks[0].attn.pad_b == expect_h - input_h + assert model.stages[2].blocks[0].attn.pad_r == expect_w - input_w + + # stage 3 + input_h = int(224 / 4 / 3 / 3 / 3) + expect_h = input_h + input_w = int(224 / 4 / 2 / 2 / 2) + expect_w = ceil(input_w / input_h) * input_h + assert model.stages[3].blocks[0].attn.pad_b == expect_h - input_h + assert model.stages[3].blocks[0].attn.pad_r == expect_w - input_w + + # Test small with auto_pad = False + with pytest.raises(AssertionError): + model = SwinTransformer( + arch='small', + auto_pad=False, + stage_cfgs=dict( + block_cfgs={'window_size': 7}, + downsample_cfg={ + 'kernel_size': (3, 2), + })) + + # Test drop_path_rate decay + model = SwinTransformer( + arch='small', + drop_path_rate=0.2, + ) + depths = model.arch_settings['depths'] + pos = 0 + for i, depth in enumerate(depths): + for j in range(depth): + block = model.stages[i].blocks[j] + expect_prob = 0.2 / (sum(depths) - 1) * pos + assert np.isclose(block.ffn.dropout_layer.drop_prob, expect_prob) + assert np.isclose(block.attn.drop.drop_prob, expect_prob) + pos += 1
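For reference, the stochastic-depth decay rule asserted in the last test can be reproduced in isolation: the backbone spreads `drop_path_rate` linearly over all blocks with `torch.linspace(0, drop_path_rate, total_depth)`, which equals the closed form `drop_path_rate / (total_depth - 1) * pos` used in the test. A standalone sketch:

```python
# Standalone illustration of the drop-path decay rule checked above.
import numpy as np
import torch

drop_path_rate = 0.2
depths = [2, 2, 18, 2]  # Swin-S stage depths
total_depth = sum(depths)

# Rule used by the backbone (see SwinTransformer.__init__).
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)]
# Closed form asserted in the unit test.
expected = [drop_path_rate / (total_depth - 1) * pos for pos in range(total_depth)]

assert np.allclose(dpr, expected)
print(dpr[:4])  # [0.0, ~0.0087, ~0.0174, ~0.0261]
```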