[WIP] Add Swin Transformer (#511)

* add Swin Transformer

* add Swin Transformer

* fixed import

* Add some swin training settings.

* Fix some filename errors.

* Fix attribute name: pretrain -> pretrained

* Upload mmcls implementation of swin transformer.

* Refactor Swin Transformer to follow mmcls style.

* Refactor init_weights of swin_transformer.py

* Fix lint

* Match inference precision

* Add some comments

* Add swin_convert to load official style ckpt

* Remove arg: auto_pad

* 1. Complete comments for each block;

2. Correct weight convert function;

3. Fix the pad of Patch Merging;

* Clean function args.

* Fix vit unit test.

* 1. Add swin transformer unit tests;

2. Fix some pad bug;

3. Modify config to adapt new swin implementation;

* Modify config arg

* Update readme.md of swin

* Fix config arg error and add some Swin benchmark info.

* Add memory (GB) and multi-scale testing results to the readme.md of Swin Transformer.

* Fix docstring of swin module

* 1. Register swin transformer to model list;

2. Modify the .pth URL so that it keeps the meta attribute;

* Update swin.py

* Merge config settings.

* Modify config style.

* Update README.md

Add ViT link

* Modify main readme.md

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
Co-authored-by: sennnnn <201730271412@mail.scut.edu.cn>
Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
pull/1801/head
Ze Liu 2021-07-01 23:41:55 +08:00 committed by GitHub
parent 5245edb0a0
commit 214d083cce
19 changed files with 1242 additions and 97 deletions

View File

@@ -59,11 +59,12 @@ Supported backbones:
 - [x] ResNet (CVPR'2016)
 - [x] ResNeXt (CVPR'2017)
-- [x] [HRNet (CVPR'2019)](configs/hrnet/README.md)
-- [x] [ResNeSt (ArXiv'2020)](configs/resnest/README.md)
-- [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2/README.md)
-- [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3/README.md)
-- [x] [Vision Transformer (ICLR'2021)]
+- [x] [HRNet (CVPR'2019)](configs/hrnet)
+- [x] [ResNeSt (ArXiv'2020)](configs/resnest)
+- [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2)
+- [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3)
+- [x] [Vision Transformer (ICLR'2021)](configs/vit)
+- [x] [Swin Transformer (arXiV'2021)](configs/swin)

 Supported methods:

@@ -71,7 +72,7 @@ Supported methods:
 - [x] [UNet (MICCAI'2016/Nat. Methods'2019)](configs/unet)
 - [x] [PSPNet (CVPR'2017)](configs/pspnet)
 - [x] [DeepLabV3 (ArXiv'2017)](configs/deeplabv3)
-- [x] [Mixed Precision (FP16) Training (ArXiv'2017)](configs/fp16/README.md)
+- [x] [Mixed Precision (FP16) Training (ArXiv'2017)](configs/fp16)
 - [x] [PSANet (ECCV'2018)](configs/psanet)
 - [x] [DeepLabV3+ (CVPR'2018)](configs/deeplabv3plus)
 - [x] [UPerNet (ECCV'2018)](configs/upernet)

View File

@@ -58,18 +58,20 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
 - [x] ResNet (CVPR'2016)
 - [x] ResNeXt (CVPR'2017)
-- [x] [HRNet (CVPR'2019)](configs/hrnet/README.md)
-- [x] [ResNeSt (ArXiv'2020)](configs/resnest/README.md)
-- [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2/README.md)
-- [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3/README.md)
-- [x] [Vision Transformer (ICLR'2021)]
+- [x] [HRNet (CVPR'2019)](configs/hrnet)
+- [x] [ResNeSt (ArXiv'2020)](configs/resnest)
+- [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2)
+- [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3)
+- [x] [Vision Transformer (ICLR'2021)](configs/vit)
+- [x] [Swin Transformer (arXiV'2021)](configs/swin)

 已支持的算法:

 - [x] [FCN (CVPR'2015/TPAMI'2017)](configs/fcn)
+- [x] [UNet (MICCAI'2016/Nat. Methods'2019)](configs/unet)
 - [x] [PSPNet (CVPR'2017)](configs/pspnet)
-- [x] [DeepLabV3 (CVPR'2017)](configs/deeplabv3)
-- [x] [Mixed Precision (FP16) Training (ArXiv'2017)](configs/fp16/README.md)
+- [x] [DeepLabV3 (ArXiv'2017)](configs/deeplabv3)
+- [x] [Mixed Precision (FP16) Training (ArXiv'2017)](configs/fp16)
 - [x] [PSANet (ECCV'2018)](configs/psanet)
 - [x] [DeepLabV3+ (CVPR'2018)](configs/deeplabv3plus)
 - [x] [UPerNet (ECCV'2018)](configs/upernet)

View File

@@ -0,0 +1,55 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
backbone_norm_cfg = dict(type='LN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(
type='SwinTransformer',
pretrain_img_size=224,
embed_dims=96,
patch_size=4,
window_size=7,
mlp_ratio=4,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
strides=(4, 2, 2, 2),
out_indices=(0, 1, 2, 3),
qkv_bias=True,
qk_scale=None,
patch_norm=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.3,
use_abs_pos_embed=False,
act_cfg=dict(type='GELU'),
norm_cfg=backbone_norm_cfg,
pretrain_style='official'),
decode_head=dict(
type='UPerHead',
in_channels=[96, 192, 384, 768],
in_index=[0, 1, 2, 3],
pool_scales=(1, 2, 3, 6),
channels=512,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=384,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

View File

@@ -0,0 +1,27 @@
# Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
## Introduction
[ALGORITHM]
```latex
@article{liu2021Swin,
title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},
journal={arXiv preprint arXiv:2103.14030},
year={2021}
}
```
## Results and models
### ADE20K
| Method | Backbone | Crop Size | pretrain | pretrain img size | Batch Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ---------- | ------- | -------- | --- | --- | -------------- | ----- | ------------: | -------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| UperNet | Swin-T | 512x512 | ImageNet-1K | 224x224 | 16 | 160000 | 5.02 | 21.06 | 44.41 | 45.79 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210531_112542-e380ad3e.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210531_112542.log.json) |
| UperNet | Swin-S | 512x512 | ImageNet-1K | 224x224 | 16 | 160000 | 6.17 | 14.72 | 47.72 | 49.24 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192015-ee2fff1c.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192015.log.json) |
| UperNet | Swin-B | 512x512 | ImageNet-1K | 224x224 | 16 | 160000 | 7.61 | 12.65 | 47.99 | 49.57 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192340-593b0e13.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192340.log.json) |
| UperNet | Swin-B | 512x512 | ImageNet-22K | 224x224 | 16 | 160000 | - | - | 50.31 | 51.9 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K_20210526_211650-762e2178.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K_20210526_211650.log.json) |
| UperNet | Swin-B | 512x512 | ImageNet-1K | 384x384 | 16 | 160000 | 8.52 | 12.10 | 48.35 | 49.65 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K_20210531_132020-05b22ea4.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K_20210531_132020.log.json) |
| UperNet | Swin-B | 512x512 | ImageNet-22K | 384x384 | 16 | 160000 | - | - | 50.76 | 52.4 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K_20210531_125459-429057bf.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K_20210531_125459.log.json) |
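
For reference, any checkpoint in the table above can be tried out with the high-level `mmseg.apis` helpers. The snippet below is a minimal sketch; the config/checkpoint paths and the demo image name are placeholders for files you have downloaded locally.

```python
from mmseg.apis import inference_segmentor, init_segmentor

# Placeholder paths: point these at the config and checkpoint you downloaded
# from the table above.
config_file = 'configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py'
checkpoint_file = 'upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210531_112542-e380ad3e.pth'

model = init_segmentor(config_file, checkpoint_file, device='cuda:0')
result = inference_segmentor(model, 'demo.png')  # per-pixel ADE20K class indices
```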

View File

@@ -0,0 +1,15 @@
_base_ = [
'upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_'
'pretrain_224x224_1K.py'
]
model = dict(
pretrained=\
'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth', # noqa
backbone=dict(
pretrain_img_size=384,
embed_dims=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
window_size=12),
decode_head=dict(in_channels=[128, 256, 512, 1024], num_classes=150),
auxiliary_head=dict(in_channels=512, num_classes=150))

View File

@@ -0,0 +1,8 @@
_base_ = [
'./upernet_swin_base_patch4_window12_512x512_160k_ade20k_'
'pretrain_384x384_1K.py'
]
model = dict(
pretrained=\
'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth', # noqa
)

View File

@@ -0,0 +1,13 @@
_base_ = [
'./upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_'
'pretrain_224x224_1K.py'
]
model = dict(
pretrained=\
'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth', # noqa
backbone=dict(
embed_dims=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32]),
decode_head=dict(in_channels=[128, 256, 512, 1024], num_classes=150),
auxiliary_head=dict(in_channels=512, num_classes=150))

View File

@@ -0,0 +1,8 @@
_base_ = [
'./upernet_swin_base_patch4_window7_512x512_160k_ade20k_'
'pretrain_224x224_1K.py'
]
model = dict(
pretrained=\
'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth', # noqa
)

View File

@@ -0,0 +1,17 @@
_base_ = [
'./upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_'
'pretrain_224x224_1K.py'
]
model = dict(
pretrained=\
'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth', # noqa
backbone=dict(
depths=[2, 2, 18, 2]),
decode_head=dict(
in_channels=[96, 192, 384, 768],
num_classes=150
),
auxiliary_head=dict(
in_channels=384,
num_classes=150
))

View File

@@ -0,0 +1,46 @@
_base_ = [
'../_base_/models/upernet_swin.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
model = dict(
pretrained=\
'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth', # noqa
backbone=dict(
embed_dims=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
use_abs_pos_embed=False,
drop_path_rate=0.3,
patch_norm=True,
pretrain_style='official'),
decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
auxiliary_head=dict(in_channels=384, num_classes=150))
# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.00006,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))
lr_config = dict(
_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)
# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
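
The config above composes the Swin-T backbone with UPerHead via the `_base_` files; a quick way to sanity-check the composed model is to build it with mmcv's `Config` and mmseg's `build_segmentor`. This is a sketch under the assumption that the file is saved at the path below inside an mmsegmentation checkout.

```python
from mmcv import Config

from mmseg.models import build_segmentor

cfg = Config.fromfile(
    'configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_'
    'pretrain_224x224_1K.py')
cfg.model.pretrained = None  # skip downloading the ImageNet-pretrained weights
model = build_segmentor(
    cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
print(type(model).__name__)  # EncoderDecoder (SwinTransformer backbone + UPerHead)
```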

View File

@@ -6,11 +6,12 @@ from .mobilenet_v3 import MobileNetV3
 from .resnest import ResNeSt
 from .resnet import ResNet, ResNetV1c, ResNetV1d
 from .resnext import ResNeXt
+from .swin import SwinTransformer
 from .unet import UNet
 from .vit import VisionTransformer

 __all__ = [
     'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
     'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
-    'VisionTransformer'
+    'VisionTransformer', 'SwinTransformer'
 ]

View File

@@ -0,0 +1,778 @@
import warnings
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer, trunc_normal_init
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.cnn.bricks.transformer import FFN, build_dropout
from mmcv.cnn.utils.weight_init import constant_init
from mmcv.runner import _load_checkpoint
from mmcv.runner.base_module import BaseModule, ModuleList
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm
from torch.nn.modules.utils import _pair as to_2tuple
from ...utils import get_root_logger
from ..builder import BACKBONES
from ..utils import PatchEmbed, swin_convert
class PatchMerging(BaseModule):
"""Merge patch feature map.
This layer uses nn.Unfold to group the feature map by kernel_size, and uses
a norm and a linear layer to embed the grouped feature map.
Args:
in_channels (int): The num of input channels.
out_channels (int): The num of output channels.
stride (int | tuple): The stride of the sliding window in the
unfold layer. Defaults: 2 (equal to the kernel_size).
bias (bool, optional): Whether to add bias in linear layer or not.
Defaults: False.
norm_cfg (dict, optional): Config dict for normalization layer.
Defaults: dict(type='LN').
init_cfg (dict, optional): The extra config for initialization.
Defaults: None.
"""
def __init__(self,
in_channels,
out_channels,
stride=2,
bias=False,
norm_cfg=dict(type='LN'),
init_cfg=None):
super().__init__(init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.stride = stride
self.sampler = nn.Unfold(
kernel_size=stride, dilation=1, padding=0, stride=stride)
sample_dim = stride**2 * in_channels
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
else:
self.norm = None
self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
def forward(self, x, hw_shape):
"""
x: x.shape -> [B, H*W, C]
hw_shape: (H, W)
"""
B, L, C = x.shape
H, W = hw_shape
assert L == H * W, 'input feature has wrong size'
x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
# stride is fixed to be equal to kernel_size.
if (H % self.stride != 0) or (W % self.stride != 0):
x = F.pad(x, (0, W % self.stride, 0, H % self.stride))
# Use nn.Unfold to merge patch. About 25% faster than original method,
# but need to modify pretrained model for compatibility
x = self.sampler(x) # B, 4*C, H/2*W/2
x = x.transpose(1, 2) # B, H/2*W/2, 4*C
x = self.norm(x) if self.norm else x
x = self.reduction(x)
down_hw_shape = (H + 1) // 2, (W + 1) // 2
return x, down_hw_shape
@ATTENTION.register_module()
class WindowMSA(BaseModule):
"""Window based multi-head self-attention (W-MSA) module with relative
position bias.
Args:
embed_dims (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Default: 0.0
proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.0
init_cfg (dict | None, optional): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size,
qkv_bias=True,
qk_scale=None,
attn_drop_rate=0.,
proj_drop_rate=0.,
init_cfg=None):
super().__init__()
self.embed_dims = embed_dims
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_embed_dims = embed_dims // num_heads
self.scale = qk_scale or head_embed_dims**-0.5
self.init_cfg = init_cfg
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# About 2x faster than original impl
Wh, Ww = self.window_size
rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
rel_position_index = rel_index_coords + rel_index_coords.T
rel_position_index = rel_position_index.flip(1).contiguous()
self.register_buffer('relative_position_index', rel_position_index)
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop_rate)
self.proj = nn.Linear(embed_dims, embed_dims)
self.proj_drop = nn.Dropout(proj_drop_rate)
self.softmax = nn.Softmax(dim=-1)
def init_weights(self):
trunc_normal_init(self.relative_position_bias_table, std=0.02)
def forward(self, x, mask=None):
"""
Args:
x (tensor): input features with shape of (num_windows*B, N, C)
mask (tensor | None, Optional): mask with shape of (num_windows,
Wh*Ww, Wh*Ww), value should be between (-inf, 0].
"""
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[
2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B // nW, nW, self.num_heads, N,
N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@staticmethod
def double_step_seq(step1, len1, step2, len2):
seq1 = torch.arange(0, step1 * len1, step1)
seq2 = torch.arange(0, step2 * len2, step2)
return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
@ATTENTION.register_module()
class ShiftWindowMSA(BaseModule):
"""Shift Window Multihead Self-Attention Module.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window.
shift_size (int, optional): The shift step of each window towards
right-bottom. If zero, act as regular window-msa. Defaults to 0.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: True
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Defaults: None.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Defaults: 0.
proj_drop_rate (float, optional): Dropout ratio of output.
Defaults: 0.
dropout_layer (dict, optional): The dropout_layer used before output.
Defaults: dict(type='DropPath', drop_prob=0.).
init_cfg (dict, optional): The extra config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size,
shift_size=0,
qkv_bias=True,
qk_scale=None,
attn_drop_rate=0,
proj_drop_rate=0,
dropout_layer=dict(type='DropPath', drop_prob=0.),
init_cfg=None):
super().__init__(init_cfg)
self.window_size = window_size
self.shift_size = shift_size
assert 0 <= self.shift_size < self.window_size
self.w_msa = WindowMSA(
embed_dims=embed_dims,
num_heads=num_heads,
window_size=to_2tuple(window_size),
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop_rate=attn_drop_rate,
proj_drop_rate=proj_drop_rate,
init_cfg=None)
self.drop = build_dropout(dropout_layer)
def forward(self, query, hw_shape):
B, L, C = query.shape
H, W = hw_shape
assert L == H * W, 'input feature has wrong size'
query = query.view(B, H, W, C)
# pad feature maps to multiples of window size
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
H_pad, W_pad = query.shape[1], query.shape[2]
# cyclic shift
if self.shift_size > 0:
shifted_query = torch.roll(
query,
shifts=(-self.shift_size, -self.shift_size),
dims=(1, 2))
# calculate attention mask for SW-MSA
img_mask = torch.zeros((1, H_pad, W_pad, 1),
device=query.device) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
# nW, window_size, window_size, 1
mask_windows = self.window_partition(img_mask)
mask_windows = mask_windows.view(
-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0,
float(-100.0)).masked_fill(
attn_mask == 0, float(0.0))
else:
shifted_query = query
attn_mask = None
# nW*B, window_size, window_size, C
query_windows = self.window_partition(shifted_query)
# nW*B, window_size*window_size, C
query_windows = query_windows.view(-1, self.window_size**2, C)
# W-MSA/SW-MSA (nW*B, window_size*window_size, C)
attn_windows = self.w_msa(query_windows, mask=attn_mask)
# merge windows
attn_windows = attn_windows.view(-1, self.window_size,
self.window_size, C)
# B H' W' C
shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(
shifted_x,
shifts=(self.shift_size, self.shift_size),
dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = self.drop(x)
return x
def window_reverse(self, windows, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
window_size = self.window_size
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size,
window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
def window_partition(self, x):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
window_size = self.window_size
x = x.view(B, H // window_size, window_size, W // window_size,
window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
windows = windows.view(-1, window_size, window_size, C)
return windows
class SwinBlock(BaseModule):
""""
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
window_size (int, optional): The local window scale. Default: 7.
shift (bool): Whether to shift the window or not. Default: False.
qkv_bias (bool, optional): Enable bias for qkv if True. Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop_rate (float, optional): Dropout rate. Default: 0.
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2.
act_cfg (dict, optional): The config dict of activation function.
Default: dict(type='GELU').
norm_cfg (dict, optional): The config dict of normalization.
Default: dict(type='LN').
init_cfg (dict | list | None, optional): The init config.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
window_size=7,
shift=False,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_cfg=None):
super(SwinBlock, self).__init__()
self.init_cfg = init_cfg
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
self.attn = ShiftWindowMSA(
embed_dims=embed_dims,
num_heads=num_heads,
window_size=window_size,
shift_size=window_size // 2 if shift else 0,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop_rate=attn_drop_rate,
proj_drop_rate=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
init_cfg=None)
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=2,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg,
add_identity=True,
init_cfg=None)
def forward(self, x, hw_shape):
identity = x
x = self.norm1(x)
x = self.attn(x, hw_shape)
x = x + identity
identity = x
x = self.norm2(x)
x = self.ffn(x, identity=identity)
return x
class SwinBlockSequence(BaseModule):
"""Implements one stage in Swin Transformer.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
depth (int): The number of blocks in this stage.
window_size (int): The local window scale. Default: 7.
qkv_bias (bool): Enable bias for qkv if True. Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop_rate (float, optional): Dropout rate. Default: 0.
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2.
downsample (BaseModule | None, optional): The downsample operation
module. Default: None.
act_cfg (dict, optional): The config dict of activation function.
Default: dict(type='GELU').
norm_cfg (dict, optional): The config dict of normalization.
Default: dict(type='LN').
init_cfg (dict | list | None, optional): The init config.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
depth,
window_size=7,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
downsample=None,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_cfg=None):
super().__init__()
self.init_cfg = init_cfg
drop_path_rate = drop_path_rate if isinstance(
drop_path_rate,
list) else [deepcopy(drop_path_rate) for _ in range(depth)]
self.blocks = ModuleList()
for i in range(depth):
block = SwinBlock(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=feedforward_channels,
window_size=window_size,
shift=False if i % 2 == 0 else True,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate[i],
act_cfg=act_cfg,
norm_cfg=norm_cfg,
init_cfg=None)
self.blocks.append(block)
self.downsample = downsample
def forward(self, x, hw_shape):
for block in self.blocks:
x = block(x, hw_shape)
if self.downsample:
x_down, down_hw_shape = self.downsample(x, hw_shape)
return x_down, down_hw_shape, x, hw_shape
else:
return x, hw_shape, x, hw_shape
@BACKBONES.register_module()
class SwinTransformer(BaseModule):
""" Swin Transformer
A PyTorch implementation of: `Swin Transformer:
Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/abs/2103.14030
Inspiration from
https://github.com/microsoft/Swin-Transformer
Args:
pretrain_img_size (int | tuple[int]): The size of the input image used
when pretraining. Defaults: 224.
in_channels (int): The num of input channels.
Defaults: 3.
embed_dims (int): The feature dimension. Default: 96.
patch_size (int | tuple[int]): Patch size. Default: 4.
window_size (int): Window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
Default: 4.
depths (tuple[int]): Depths of each Swin Transformer stage.
Default: (2, 2, 6, 2).
num_heads (tuple[int]): Parallel attention heads of each Swin
Transformer stage. Default: (3, 6, 12, 24).
strides (tuple[int]): The patch merging or patch embedding stride of
each Swin Transformer stage. (In swin, we set kernel size equal to
stride.) Default: (4, 2, 2, 2).
out_indices (tuple[int]): Output from which stages.
Default: (0, 1, 2, 3).
qkv_bias (bool, optional): If True, add a learnable bias to query, key,
value. Default: True
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
patch_norm (bool): If add a norm layer for patch embed and patch
merging. Default: True.
drop_rate (float): Dropout rate. Defaults: 0.
attn_drop_rate (float): Attention dropout rate. Default: 0.
drop_path_rate (float): Stochastic depth rate. Defaults: 0.1.
use_abs_pos_embed (bool): If True, add absolute position embedding to
the patch embedding. Defaults: False.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer at
output of backone. Defaults: dict(type='LN').
pretrain_style (str): Choose to use official or mmcls pretrain weights.
Default: official.
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
pretrain_img_size=224,
in_channels=3,
embed_dims=96,
patch_size=4,
window_size=7,
mlp_ratio=4,
depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24),
strides=(4, 2, 2, 2),
out_indices=(0, 1, 2, 3),
qkv_bias=True,
qk_scale=None,
patch_norm=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.1,
use_abs_pos_embed=False,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
pretrain_style='official',
pretrained=None,
init_cfg=None):
super(SwinTransformer, self).__init__()
if isinstance(pretrain_img_size, int):
pretrain_img_size = to_2tuple(pretrain_img_size)
elif isinstance(pretrain_img_size, tuple):
if len(pretrain_img_size) == 1:
pretrain_img_size = to_2tuple(pretrain_img_size[0])
assert len(pretrain_img_size) == 2, \
f'The size of image should have length 1 or 2, ' \
f'but got {len(pretrain_img_size)}'
assert pretrain_style in ['official', 'mmcls'], (
    'We only support loading official or mmcls pretrained ckpt.')
if isinstance(pretrained, str) or pretrained is None:
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
else:
raise TypeError('pretrained must be a str or None')
num_layers = len(depths)
self.out_indices = out_indices
self.use_abs_pos_embed = use_abs_pos_embed
self.pretrain_style = pretrain_style
self.pretrained = pretrained
self.init_cfg = init_cfg
assert strides[0] == patch_size, 'Use non-overlapping patch embed.'
self.patch_embed = PatchEmbed(
in_channels=in_channels,
embed_dims=embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=strides[0],
norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None)
if self.use_abs_pos_embed:
patch_row = pretrain_img_size[0] // patch_size
patch_col = pretrain_img_size[1] // patch_size
num_patches = patch_row * patch_col
self.absolute_pos_embed = nn.Parameter(
torch.zeros((1, num_patches, embed_dims)))
self.drop_after_pos = nn.Dropout(p=drop_rate)
# stochastic depth
total_depth = sum(depths)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
] # stochastic depth decay rule
self.stages = ModuleList()
in_channels = embed_dims
for i in range(num_layers):
if i < num_layers - 1:
downsample = PatchMerging(
in_channels=in_channels,
out_channels=2 * in_channels,
stride=strides[i + 1],
norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None)
else:
downsample = None
stage = SwinBlockSequence(
embed_dims=in_channels,
num_heads=num_heads[i],
feedforward_channels=mlp_ratio * in_channels,
depth=depths[i],
window_size=window_size,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=dpr[:depths[i]],
downsample=downsample,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
init_cfg=None)
self.stages.append(stage)
dpr = dpr[depths[i]:]
if downsample:
in_channels = downsample.out_channels
self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)]
# Add a norm layer for each output
for i in out_indices:
layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
layer_name = f'norm{i}'
self.add_module(layer_name, layer)
def init_weights(self):
if self.pretrained is None:
super().init_weights()
if self.use_abs_pos_embed:
trunc_normal_init(self.absolute_pos_embed, std=0.02)
for m in self.modules():
if isinstance(m, Linear):
trunc_normal_init(m.weight, std=.02)
if m.bias is not None:
constant_init(m.bias, 0)
elif isinstance(m, LayerNorm):
constant_init(m.bias, 0)
constant_init(m.weight, 1.0)
elif isinstance(self.pretrained, str):
logger = get_root_logger()
ckpt = _load_checkpoint(
self.pretrained, logger=logger, map_location='cpu')
if 'state_dict' in ckpt:
state_dict = ckpt['state_dict']
elif 'model' in ckpt:
state_dict = ckpt['model']
else:
state_dict = ckpt
if self.pretrain_style == 'official':
state_dict = swin_convert(state_dict)
# strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in state_dict.items()}
# reshape absolute position embedding
if state_dict.get('absolute_pos_embed') is not None:
absolute_pos_embed = state_dict['absolute_pos_embed']
N1, L, C1 = absolute_pos_embed.size()
N2, C2, H, W = self.absolute_pos_embed.size()
if N1 != N2 or C1 != C2 or L != H * W:
logger.warning('Error in loading absolute_pos_embed, pass')
else:
state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
# interpolate position bias table if needed
relative_position_bias_table_keys = [
k for k in state_dict.keys()
if 'relative_position_bias_table' in k
]
for table_key in relative_position_bias_table_keys:
table_pretrained = state_dict[table_key]
table_current = self.state_dict()[table_key]
L1, nH1 = table_pretrained.size()
L2, nH2 = table_current.size()
if nH1 != nH2:
logger.warning(f'Error in loading {table_key}, pass')
else:
if L1 != L2:
S1 = int(L1**0.5)
S2 = int(L2**0.5)
table_pretrained_resized = F.interpolate(
table_pretrained.permute(1, 0).reshape(
1, nH1, S1, S1),
size=(S2, S2),
mode='bicubic')
state_dict[table_key] = table_pretrained_resized.view(
nH2, L2).permute(1, 0).contiguous()
# load state_dict
self.load_state_dict(state_dict, False)
def forward(self, x):
x = self.patch_embed(x)
hw_shape = (self.patch_embed.DH, self.patch_embed.DW)
if self.use_abs_pos_embed:
x = x + self.absolute_pos_embed
x = self.drop_after_pos(x)
outs = []
for i, stage in enumerate(self.stages):
x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(out)
out = out.view(-1, *out_hw_shape,
self.num_features[i]).permute(0, 3, 1,
2).contiguous()
outs.append(out)
return outs
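
As a quick sanity check (mirroring the unit test added further below), the default constructor corresponds to Swin-T, and the four stages return feature maps at strides 4, 8, 16 and 32:

```python
import torch

from mmseg.models.backbones import SwinTransformer

model = SwinTransformer()  # defaults: Swin-T (embed_dims=96, depths=(2, 2, 6, 2))
model.eval()

with torch.no_grad():
    outs = model(torch.randn(1, 3, 512, 512))

for out in outs:
    print(tuple(out.shape))
# (1, 96, 128, 128), (1, 192, 64, 64), (1, 384, 32, 32), (1, 768, 16, 16)
```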

View File

@@ -4,8 +4,8 @@ import warnings
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
-                      kaiming_init, normal_init, trunc_normal_init)
+from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init,
+                      normal_init, trunc_normal_init)
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
 from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
@@ -13,7 +13,7 @@ from torch.nn.modules.utils import _pair as to_2tuple
 from mmseg.utils import get_root_logger
 from ..builder import BACKBONES
-from ..utils import vit_convert
+from ..utils import PatchEmbed, vit_convert


 class TransformerEncoderLayer(BaseModule):
@@ -93,49 +93,6 @@ class TransformerEncoderLayer(BaseModule):
         return x


-# Modified from pytorch-image-models
-class PatchEmbed(BaseModule):
-    """Image to Patch Embedding.
-
-    Args:
-        patch_size (int): The size of one patch
-        in_channels (int): The num of input channels.
-        embed_dims (int): The dimensions of embedding.
-        norm_cfg (dict, optional): Config dict for normalization layer.
-        conv_cfg (dict, optional): The config dict for conv layers.
-            Default: None.
-    """
-
-    def __init__(self,
-                 patch_size=16,
-                 in_channels=3,
-                 embed_dims=768,
-                 norm_cfg=None,
-                 conv_cfg=None):
-        super(PatchEmbed, self).__init__()
-
-        # Use conv layer to embed
-        self.projection = build_conv_layer(
-            conv_cfg,
-            in_channels,
-            embed_dims,
-            kernel_size=patch_size,
-            stride=patch_size)
-
-        if norm_cfg is not None:
-            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
-        else:
-            self.norm = None
-
-    def forward(self, x):
-        x = self.projection(x).flatten(2).transpose(1, 2)
-
-        if self.norm is not None:
-            x = self.norm(x)
-
-        return x
-
-
 @BACKBONES.register_module()
 class VisionTransformer(BaseModule):
     """Vision Transformer.
@@ -248,10 +205,14 @@ class VisionTransformer(BaseModule):
         self.init_cfg = init_cfg

         self.patch_embed = PatchEmbed(
-            patch_size=patch_size,
             in_channels=in_channels,
             embed_dims=embed_dims,
-            norm_cfg=norm_cfg if patch_norm else None)
+            conv_type='Conv2d',
+            kernel_size=patch_size,
+            stride=patch_size,
+            norm_cfg=norm_cfg if patch_norm else None,
+            init_cfg=None,
+        )

         num_patches = (img_size[0] // patch_size) * \
             (img_size[1] // patch_size)

View File

@@ -1,12 +1,14 @@
+from .ckpt_convert import swin_convert, vit_convert
+from .embed import PatchEmbed
 from .inverted_residual import InvertedResidual, InvertedResidualV3
 from .make_divisible import make_divisible
 from .res_layer import ResLayer
 from .se_layer import SELayer
 from .self_attention_block import SelfAttentionBlock
-from .timm_convert import vit_convert
 from .up_conv_block import UpConvBlock

 __all__ = [
     'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
-    'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'vit_convert'
+    'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'vit_convert',
+    'swin_convert', 'PatchEmbed'
 ]

View File

@@ -0,0 +1,90 @@
from collections import OrderedDict
def swin_convert(ckpt):
new_ckpt = OrderedDict()
def correct_unfold_reduction_order(x):
out_channel, in_channel = x.shape
x = x.reshape(out_channel, 4, in_channel // 4)
x = x[:, [0, 2, 1, 3], :].transpose(1,
2).reshape(out_channel, in_channel)
return x
def correct_unfold_norm_order(x):
in_channel = x.shape[0]
x = x.reshape(4, in_channel // 4)
x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
return x
for k, v in ckpt.items():
if k.startswith('head'):
continue
elif k.startswith('layers'):
new_v = v
if 'attn.' in k:
new_k = k.replace('attn.', 'attn.w_msa.')
elif 'mlp.' in k:
if 'mlp.fc1.' in k:
new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
elif 'mlp.fc2.' in k:
new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
else:
new_k = k.replace('mlp.', 'ffn.')
elif 'downsample' in k:
new_k = k
if 'reduction.' in k:
new_v = correct_unfold_reduction_order(v)
elif 'norm.' in k:
new_v = correct_unfold_norm_order(v)
else:
new_k = k
new_k = new_k.replace('layers', 'stages', 1)
elif k.startswith('patch_embed'):
new_v = v
if 'proj' in k:
new_k = k.replace('proj', 'projection')
else:
new_k = k
else:
new_v = v
new_k = k
new_ckpt[new_k] = new_v
return new_ckpt
def vit_convert(ckpt):
new_ckpt = OrderedDict()
for k, v in ckpt.items():
if k.startswith('head'):
continue
if k.startswith('norm'):
new_k = k.replace('norm.', 'ln1.')
elif k.startswith('patch_embed'):
if 'proj' in k:
new_k = k.replace('proj', 'projection')
else:
new_k = k
elif k.startswith('blocks'):
if 'norm' in k:
new_k = k.replace('norm', 'ln')
elif 'mlp.fc1' in k:
new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
elif 'mlp.fc2' in k:
new_k = k.replace('mlp.fc2', 'ffn.layers.1')
elif 'attn.qkv' in k:
new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_')
elif 'attn.proj' in k:
new_k = k.replace('attn.proj', 'attn.attn.out_proj')
else:
new_k = k
new_k = new_k.replace('blocks.', 'layers.')
else:
new_k = k
new_ckpt[new_k] = v
return new_ckpt
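
`swin_convert` can also be applied offline to re-key an official Swin classification checkpoint into the mmseg naming scheme (the backbone already calls it on the fly when `pretrain_style='official'`). A hedged sketch, with placeholder file names:

```python
import torch

from mmseg.models.utils import swin_convert

ckpt = torch.load('swin_tiny_patch4_window7_224.pth', map_location='cpu')  # placeholder path
# Official Swin classification checkpoints keep the weights under the 'model' key.
state_dict = ckpt.get('model', ckpt)
torch.save(dict(state_dict=swin_convert(state_dict)),
           'swin_tiny_patch4_window7_224_mmseg.pth')  # placeholder output path
```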

View File

@@ -0,0 +1,89 @@
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner.base_module import BaseModule
from torch.nn.modules.utils import _pair as to_2tuple
# Modified from Pytorch-Image-Models
class PatchEmbed(BaseModule):
"""Image to Patch Embedding V2.
We use a conv layer to implement PatchEmbed.
Args:
in_channels (int): The num of input channels. Default: 3
embed_dims (int): The dimensions of embedding. Default: 768
conv_type (str, optional): The type of convolution layer used to embed
the patches. Default: None (falls back to 'Conv2d').
kernel_size (int): The kernel_size of embedding conv. Default: 16.
stride (int): The stride of the embedding conv. Default: 16 (if set
to None, it falls back to kernel_size).
padding (int): The padding length of embedding conv. Default: 0.
dilation (int): The dilation rate of embedding conv. Default: 1.
norm_cfg (dict, optional): Config dict for normalization layer.
init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
Default: None.
"""
def __init__(self,
in_channels=3,
embed_dims=768,
conv_type=None,
kernel_size=16,
stride=16,
padding=0,
dilation=1,
norm_cfg=None,
init_cfg=None):
super(PatchEmbed, self).__init__()
self.embed_dims = embed_dims
self.init_cfg = init_cfg
if stride is None:
stride = kernel_size
# The default setting of patch size is equal to kernel size.
patch_size = kernel_size
if isinstance(patch_size, int):
patch_size = to_2tuple(patch_size)
elif isinstance(patch_size, tuple):
if len(patch_size) == 1:
patch_size = to_2tuple(patch_size[0])
assert len(patch_size) == 2, \
f'The size of patch should have length 1 or 2, ' \
f'but got {len(patch_size)}'
self.patch_size = patch_size
# Use conv layer to embed
conv_type = conv_type or 'Conv2d'
self.projection = build_conv_layer(
dict(type=conv_type),
in_channels=in_channels,
out_channels=embed_dims,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation)
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
else:
self.norm = None
def forward(self, x):
H, W = x.shape[2], x.shape[3]
if H % self.patch_size[0] != 0:
x = F.pad(x,
(0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
if W % self.patch_size[1] != 0:
x = F.pad(x,
(0, self.patch_size[1] - W % self.patch_size[1], 0, 0))
x = self.projection(x)
self.DH, self.DW = x.shape[2], x.shape[3]
x = x.flatten(2).transpose(1, 2)
if self.norm is not None:
x = self.norm(x)
return x
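
A small usage sketch of the new `PatchEmbed`: with the 4x4 patches used by Swin, an input that is not divisible by the patch size is padded before projection, and `DH`/`DW` record the resulting token grid.

```python
import torch

from mmseg.models.utils import PatchEmbed

patch_embed = PatchEmbed(
    in_channels=3, embed_dims=96, conv_type='Conv2d', kernel_size=4, stride=4)

x = torch.randn(1, 3, 511, 511)  # 511 is not divisible by 4; forward() pads to 512
tokens = patch_embed(x)
print(tokens.shape, patch_embed.DH, patch_embed.DW)
# torch.Size([1, 16384, 96]) 128 128
```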

View File

@@ -1,32 +0,0 @@
from collections import OrderedDict
def vit_convert(timm_dict):
mmseg_dict = OrderedDict()
for k, v in timm_dict.items():
if k.startswith('head'):
continue
if k.startswith('norm'):
new_k = k.replace('norm.', 'ln1.')
elif k.startswith('patch_embed'):
if 'proj' in k:
new_k = k.replace('proj', 'projection')
elif k.startswith('blocks'):
new_k = k.replace('blocks.', 'layers.')
if 'norm' in new_k:
new_k = new_k.replace('norm', 'ln')
elif 'mlp.fc1' in new_k:
new_k = new_k.replace('mlp.fc1', 'ffn.layers.0.0')
elif 'mlp.fc2' in new_k:
new_k = new_k.replace('mlp.fc2', 'ffn.layers.1')
elif 'attn.qkv' in new_k:
new_k = new_k.replace('attn.qkv.', 'attn.attn.in_proj_')
elif 'attn.proj' in new_k:
new_k = new_k.replace('attn.proj', 'attn.attn.out_proj')
else:
new_k = k
mmseg_dict[new_k] = v
return mmseg_dict

View File

@@ -0,0 +1,64 @@
import pytest
import torch
from mmseg.models.backbones import SwinTransformer
def test_swin_transformer():
"""Test Swin Transformer backbone."""
with pytest.raises(AssertionError):
# We only support 'official' or 'mmcls' for this arg.
model = SwinTransformer(pretrain_style='swin')
with pytest.raises(TypeError):
# Pretrained arg must be str or None.
model = SwinTransformer(pretrained=123)
with pytest.raises(AssertionError):
# Because Swin uses non-overlapping patch embedding, the stride of the
# patch embedding must be equal to the patch size.
model = SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)
# Test absolute position embedding
temp = torch.randn((1, 3, 224, 224))
model = SwinTransformer(pretrain_img_size=224, use_abs_pos_embed=True)
model.init_weights()
model(temp)
# Test patch norm
model = SwinTransformer(patch_norm=False)
model(temp)
# Test pretrain img size
model = SwinTransformer(pretrain_img_size=(224, ))
with pytest.raises(AssertionError):
model = SwinTransformer(pretrain_img_size=(224, 224, 224))
# Test normal inference
temp = torch.randn((1, 3, 512, 512))
model = SwinTransformer()
outs = model(temp)
assert outs[0].shape == (1, 96, 128, 128)
assert outs[1].shape == (1, 192, 64, 64)
assert outs[2].shape == (1, 384, 32, 32)
assert outs[3].shape == (1, 768, 16, 16)
# Test abnormal inference
temp = torch.randn((1, 3, 511, 511))
model = SwinTransformer()
outs = model(temp)
assert outs[0].shape == (1, 96, 128, 128)
assert outs[1].shape == (1, 192, 64, 64)
assert outs[2].shape == (1, 384, 32, 32)
assert outs[3].shape == (1, 768, 16, 16)
# Test abnormal inference
temp = torch.randn((1, 3, 112, 137))
model = SwinTransformer()
outs = model(temp)
assert outs[0].shape == (1, 96, 28, 35)
assert outs[1].shape == (1, 192, 14, 18)
assert outs[2].shape == (1, 384, 7, 9)
assert outs[3].shape == (1, 768, 4, 5)

View File

@@ -24,7 +24,7 @@ def test_vit_backbone():
         x = torch.randn(1, 196)
         VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear')

-    with pytest.raises(RuntimeError):
+    with pytest.raises(IndexError):
         # forward inputs must be [N, C, H, W]
         x = torch.randn(3, 30, 30)
         model = VisionTransformer()