diff --git a/README.md b/README.md index bbf24ec85..e13c6a780 100644 --- a/README.md +++ b/README.md @@ -59,11 +59,12 @@ Supported backbones: - [x] ResNet (CVPR'2016) - [x] ResNeXt (CVPR'2017) -- [x] [HRNet (CVPR'2019)](configs/hrnet/README.md) -- [x] [ResNeSt (ArXiv'2020)](configs/resnest/README.md) -- [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2/README.md) -- [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3/README.md) -- [x] [Vision Transformer (ICLR'2021)] +- [x] [HRNet (CVPR'2019)](configs/hrnet) +- [x] [ResNeSt (ArXiv'2020)](configs/resnest) +- [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2) +- [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3) +- [x] [Vision Transformer (ICLR'2021)](configs/vit) +- [x] [Swin Transformer (arXiV'2021)](configs/swin) Supported methods: @@ -71,7 +72,7 @@ Supported methods: - [x] [UNet (MICCAI'2016/Nat. Methods'2019)](configs/unet) - [x] [PSPNet (CVPR'2017)](configs/pspnet) - [x] [DeepLabV3 (ArXiv'2017)](configs/deeplabv3) -- [x] [Mixed Precision (FP16) Training (ArXiv'2017)](configs/fp16/README.md) +- [x] [Mixed Precision (FP16) Training (ArXiv'2017)](configs/fp16) - [x] [PSANet (ECCV'2018)](configs/psanet) - [x] [DeepLabV3+ (CVPR'2018)](configs/deeplabv3plus) - [x] [UPerNet (ECCV'2018)](configs/upernet) diff --git a/README_zh-CN.md b/README_zh-CN.md index 2341e4768..04191bd02 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -58,18 +58,20 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O - [x] ResNet (CVPR'2016) - [x] ResNeXt (CVPR'2017) -- [x] [HRNet (CVPR'2019)](configs/hrnet/README.md) -- [x] [ResNeSt (ArXiv'2020)](configs/resnest/README.md) -- [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2/README.md) -- [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3/README.md) -- [x] [Vision Transformer (ICLR'2021)] +- [x] [HRNet (CVPR'2019)](configs/hrnet) +- [x] [ResNeSt (ArXiv'2020)](configs/resnest) +- [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2) +- [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3) +- [x] [Vision Transformer (ICLR'2021)](configs/vit) +- [x] [Swin Transformer (arXiV'2021)](configs/swin) 已支持的算法: - [x] [FCN (CVPR'2015/TPAMI'2017)](configs/fcn) +- [x] [UNet (MICCAI'2016/Nat. Methods'2019)](configs/unet) - [x] [PSPNet (CVPR'2017)](configs/pspnet) -- [x] [DeepLabV3 (CVPR'2017)](configs/deeplabv3) -- [x] [Mixed Precision (FP16) Training (ArXiv'2017)](configs/fp16/README.md) +- [x] [DeepLabV3 (ArXiv'2017)](configs/deeplabv3) +- [x] [Mixed Precision (FP16) Training (ArXiv'2017)](configs/fp16) - [x] [PSANet (ECCV'2018)](configs/psanet) - [x] [DeepLabV3+ (CVPR'2018)](configs/deeplabv3plus) - [x] [UPerNet (ECCV'2018)](configs/upernet) diff --git a/configs/_base_/models/upernet_swin.py b/configs/_base_/models/upernet_swin.py new file mode 100644 index 000000000..30ee0503d --- /dev/null +++ b/configs/_base_/models/upernet_swin.py @@ -0,0 +1,55 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +backbone_norm_cfg = dict(type='LN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='SwinTransformer', + pretrain_img_size=224, + embed_dims=96, + patch_size=4, + window_size=7, + mlp_ratio=4, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + strides=(4, 2, 2, 2), + out_indices=(0, 1, 2, 3), + qkv_bias=True, + qk_scale=None, + patch_norm=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.3, + use_abs_pos_embed=False, + act_cfg=dict(type='GELU'), + norm_cfg=backbone_norm_cfg, + pretrain_style='official'), + decode_head=dict( + type='UPerHead', + in_channels=[96, 192, 384, 768], + in_index=[0, 1, 2, 3], + pool_scales=(1, 2, 3, 6), + channels=512, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=384, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/configs/swin/README.md b/configs/swin/README.md new file mode 100644 index 000000000..8bf3a8ab6 --- /dev/null +++ b/configs/swin/README.md @@ -0,0 +1,27 @@ +# Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + +## Introduction + +[ALGORITHM] + +```latex +@article{liu2021Swin, + title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows}, + author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining}, + journal={arXiv preprint arXiv:2103.14030}, + year={2021} +} +``` + +## Results and models + +### ADE20K + +| Method | Backbone | Crop Size | pretrain | pretrain img size | Batch Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download | +| ------ | -------- | --------- | ---------- | ------- | -------- | --- | --- | -------------- | ----- | ------------: | -------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| UperNet | Swin-T | 512x512 | ImageNet-1K | 224x224 | 16 | 160000 | 5.02 | 21.06 | 44.41 | 45.79 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210531_112542-e380ad3e.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210531_112542.log.json) | +| UperNet | Swin-S | 512x512 | ImageNet-1K | 224x224 | 16 | 160000 | 6.17 | 14.72 | 47.72 | 49.24 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192015-ee2fff1c.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192015.log.json) | +| UperNet | Swin-B | 512x512 | ImageNet-1K | 224x224 | 16 | 160000 | 7.61 | 12.65 | 47.99 | 49.57 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192340-593b0e13.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192340.log.json) | +| UperNet | Swin-B | 512x512 | ImageNet-22K | 224x224 | 16 | 160000 | - | - | 50.31 | 51.9 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K_20210526_211650-762e2178.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K_20210526_211650.log.json) | +| UperNet | Swin-B | 512x512 | ImageNet-1K | 384x384 | 16 | 160000 | 8.52 | 12.10 | 48.35 | 49.65 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K_20210531_132020-05b22ea4.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K_20210531_132020.log.json) | +| UperNet | Swin-B | 512x512 | ImageNet-22K | 384x384 | 16 | 160000 | - | - | 50.76 | 52.4 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K_20210531_125459-429057bf.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K_20210531_125459.log.json) | diff --git a/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K.py b/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K.py new file mode 100644 index 000000000..d89f57cab --- /dev/null +++ b/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K.py @@ -0,0 +1,15 @@ +_base_ = [ + 'upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_' + 'pretrain_224x224_1K.py' +] +model = dict( + pretrained=\ + 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth', # noqa + backbone=dict( + pretrain_img_size=384, + embed_dims=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=12), + decode_head=dict(in_channels=[128, 256, 512, 1024], num_classes=150), + auxiliary_head=dict(in_channels=512, num_classes=150)) diff --git a/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K.py b/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K.py new file mode 100644 index 000000000..38fed2648 --- /dev/null +++ b/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K.py @@ -0,0 +1,8 @@ +_base_ = [ + './upernet_swin_base_patch4_window12_512x512_160k_ade20k_' + 'pretrain_384x384_1K.py' +] +model = dict( + pretrained=\ + 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth', # noqa +) diff --git a/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py b/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py new file mode 100644 index 000000000..c34594a46 --- /dev/null +++ b/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py @@ -0,0 +1,13 @@ +_base_ = [ + './upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_' + 'pretrain_224x224_1K.py' +] +model = dict( + pretrained=\ + 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth', # noqa + backbone=dict( + embed_dims=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32]), + decode_head=dict(in_channels=[128, 256, 512, 1024], num_classes=150), + auxiliary_head=dict(in_channels=512, num_classes=150)) diff --git a/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K.py b/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K.py new file mode 100644 index 000000000..5bb51d878 --- /dev/null +++ b/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K.py @@ -0,0 +1,8 @@ +_base_ = [ + './upernet_swin_base_patch4_window7_512x512_160k_ade20k_' + 'pretrain_224x224_1K.py' +] +model = dict( + pretrained=\ + 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth', # noqa +) diff --git a/configs/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py b/configs/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py new file mode 100644 index 000000000..469b957c2 --- /dev/null +++ b/configs/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py @@ -0,0 +1,17 @@ +_base_ = [ + './upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_' + 'pretrain_224x224_1K.py' +] +model = dict( + pretrained=\ + 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth', # noqa + backbone=dict( + depths=[2, 2, 18, 2]), + decode_head=dict( + in_channels=[96, 192, 384, 768], + num_classes=150 + ), + auxiliary_head=dict( + in_channels=384, + num_classes=150 + )) diff --git a/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py b/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py new file mode 100644 index 000000000..7be1cf582 --- /dev/null +++ b/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py @@ -0,0 +1,46 @@ +_base_ = [ + '../_base_/models/upernet_swin.py', '../_base_/datasets/ade20k.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' +] +model = dict( + pretrained=\ + 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth', # noqa + backbone=dict( + embed_dims=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + use_abs_pos_embed=False, + drop_path_rate=0.3, + patch_norm=True, + pretrain_style='official'), + decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150), + auxiliary_head=dict(in_channels=384, num_classes=150)) + +# AdamW optimizer, no weight decay for position embedding & layer norm +# in backbone +optimizer = dict( + _delete_=True, + type='AdamW', + lr=0.00006, + betas=(0.9, 0.999), + weight_decay=0.01, + paramwise_cfg=dict( + custom_keys={ + 'absolute_pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) + })) + +lr_config = dict( + _delete_=True, + policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-6, + power=1.0, + min_lr=0.0, + by_epoch=False) + +# By default, models are trained on 8 GPUs with 2 images per GPU +data = dict(samples_per_gpu=2) diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py index eae064b6e..43690d6c8 100644 --- a/mmseg/models/backbones/__init__.py +++ b/mmseg/models/backbones/__init__.py @@ -6,11 +6,12 @@ from .mobilenet_v3 import MobileNetV3 from .resnest import ResNeSt from .resnet import ResNet, ResNetV1c, ResNetV1d from .resnext import ResNeXt +from .swin import SwinTransformer from .unet import UNet from .vit import VisionTransformer __all__ = [ 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN', 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3', - 'VisionTransformer' + 'VisionTransformer', 'SwinTransformer' ] diff --git a/mmseg/models/backbones/swin.py b/mmseg/models/backbones/swin.py new file mode 100644 index 000000000..a798ad1eb --- /dev/null +++ b/mmseg/models/backbones/swin.py @@ -0,0 +1,778 @@ +import warnings +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer, trunc_normal_init +from mmcv.cnn.bricks.registry import ATTENTION +from mmcv.cnn.bricks.transformer import FFN, build_dropout +from mmcv.cnn.utils.weight_init import constant_init +from mmcv.runner import _load_checkpoint +from mmcv.runner.base_module import BaseModule, ModuleList +from torch.nn.modules.linear import Linear +from torch.nn.modules.normalization import LayerNorm +from torch.nn.modules.utils import _pair as to_2tuple + +from ...utils import get_root_logger +from ..builder import BACKBONES +from ..utils import PatchEmbed, swin_convert + + +class PatchMerging(BaseModule): + """Merge patch feature map. + + This layer use nn.Unfold to group feature map by kernel_size, and use norm + and linear layer to embed grouped feature map. + Args: + in_channels (int): The num of input channels. + out_channels (int): The num of output channels. + stride (int | tuple): the stride of the sliding length in the + unfold layer. Defaults: 2. (Default to be equal with kernel_size). + bias (bool, optional): Whether to add bias in linear layer or not. + Defaults: False. + norm_cfg (dict, optional): Config dict for normalization layer. + Defaults: dict(type='LN'). + init_cfg (dict, optional): The extra config for initialization. + Defaults: None. + """ + + def __init__(self, + in_channels, + out_channels, + stride=2, + bias=False, + norm_cfg=dict(type='LN'), + init_cfg=None): + super().__init__(init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + + self.sampler = nn.Unfold( + kernel_size=stride, dilation=1, padding=0, stride=stride) + + sample_dim = stride**2 * in_channels + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, sample_dim)[1] + else: + self.norm = None + + self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) + + def forward(self, x, hw_shape): + """ + x: x.shape -> [B, H*W, C] + hw_shape: (H, W) + """ + B, L, C = x.shape + H, W = hw_shape + assert L == H * W, 'input feature has wrong size' + + x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W + + # stride is fixed to be equal to kernel_size. + if (H % self.stride != 0) or (W % self.stride != 0): + x = F.pad(x, (0, W % self.stride, 0, H % self.stride)) + + # Use nn.Unfold to merge patch. About 25% faster than original method, + # but need to modify pretrained model for compatibility + x = self.sampler(x) # B, 4*C, H/2*W/2 + x = x.transpose(1, 2) # B, H/2*W/2, 4*C + + x = self.norm(x) if self.norm else x + x = self.reduction(x) + + down_hw_shape = (H + 1) // 2, (W + 1) // 2 + return x, down_hw_shape + + +@ATTENTION.register_module() +class WindowMSA(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + Args: + embed_dims (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.0 + init_cfg (dict | None, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + init_cfg=None): + + super().__init__() + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = qk_scale or head_embed_dims**-0.5 + self.init_cfg = init_cfg + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # About 2x faster than original impl + Wh, Ww = self.window_size + rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww) + rel_position_index = rel_index_coords + rel_index_coords.T + rel_position_index = rel_position_index.flip(1).contiguous() + self.register_buffer('relative_position_index', rel_position_index) + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + + self.softmax = nn.Softmax(dim=-1) + + def init_weights(self): + trunc_normal_init(self.relative_position_bias_table, std=0.02) + + def forward(self, x, mask=None): + """ + Args: + + x (tensor): input features with shape of (num_windows*B, N, C) + mask (tensor | None, Optional): mask with shape of (num_windows, + Wh*Ww, Wh*Ww), value should be between (-inf, 0]. + """ + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @staticmethod + def double_step_seq(step1, len1, step2, len2): + seq1 = torch.arange(0, step1 * len1, step1) + seq2 = torch.arange(0, step2 * len2, step2) + return (seq1[:, None] + seq2[None, :]).reshape(1, -1) + + +@ATTENTION.register_module() +class ShiftWindowMSA(BaseModule): + """Shift Window Multihead Self-Attention Module. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. + shift_size (int, optional): The shift step of each window towards + right-bottom. If zero, act as regular window-msa. Defaults to 0. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Defaults: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Defaults: 0. + proj_drop_rate (float, optional): Dropout ratio of output. + Defaults: 0. + dropout_layer (dict, optional): The dropout_layer used before output. + Defaults: dict(type='DropPath', drop_prob=0.). + init_cfg (dict, optional): The extra config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + shift_size=0, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0, + proj_drop_rate=0, + dropout_layer=dict(type='DropPath', drop_prob=0.), + init_cfg=None): + super().__init__(init_cfg) + + self.window_size = window_size + self.shift_size = shift_size + assert 0 <= self.shift_size < self.window_size + + self.w_msa = WindowMSA( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=to_2tuple(window_size), + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=proj_drop_rate, + init_cfg=None) + + self.drop = build_dropout(dropout_layer) + + def forward(self, query, hw_shape): + B, L, C = query.shape + H, W = hw_shape + assert L == H * W, 'input feature has wrong size' + query = query.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b)) + H_pad, W_pad = query.shape[1], query.shape[2] + + # cyclic shift + if self.shift_size > 0: + shifted_query = torch.roll( + query, + shifts=(-self.shift_size, -self.shift_size), + dims=(1, 2)) + + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, H_pad, W_pad, 1), + device=query.device) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + # nW, window_size, window_size, 1 + mask_windows = self.window_partition(img_mask) + mask_windows = mask_windows.view( + -1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + else: + shifted_query = query + attn_mask = None + + # nW*B, window_size, window_size, C + query_windows = self.window_partition(shifted_query) + # nW*B, window_size*window_size, C + query_windows = query_windows.view(-1, self.window_size**2, C) + + # W-MSA/SW-MSA (nW*B, window_size*window_size, C) + attn_windows = self.w_msa(query_windows, mask=attn_mask) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, + self.window_size, C) + + # B H' W' C + shifted_x = self.window_reverse(attn_windows, H_pad, W_pad) + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + x = self.drop(x) + return x + + def window_reverse(self, windows, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + window_size = self.window_size + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + def window_partition(self, x): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + window_size = self.window_size + x = x.view(B, H // window_size, window_size, W // window_size, + window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = windows.view(-1, window_size, window_size, C) + return windows + + +class SwinBlock(BaseModule): + """" + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + window size (int, optional): The local window scale. Default: 7. + shift (bool): whether to shift window or not. Default False. + qkv_bias (int, optional): enable bias for qkv if True. Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop_rate (float, optional): Dropout rate. Default: 0. + attn_drop_rate (float, optional): Attention dropout rate. Default: 0. + drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2. + act_cfg (dict, optional): The config dict of activation function. + Default: dict(type='GELU'). + norm_cfg (dict, optional): The config dict of nomalization. + Default: dict(type='LN'). + init_cfg (dict | list | None, optional): The init config. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + window_size=7, + shift=False, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg=None): + + super(SwinBlock, self).__init__() + + self.init_cfg = init_cfg + + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = ShiftWindowMSA( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=window_size, + shift_size=window_size // 2 if shift else 0, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + init_cfg=None) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=2, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=True, + init_cfg=None) + + def forward(self, x, hw_shape): + identity = x + x = self.norm1(x) + x = self.attn(x, hw_shape) + + x = x + identity + + identity = x + x = self.norm2(x) + x = self.ffn(x, identity=identity) + + return x + + +class SwinBlockSequence(BaseModule): + """Implements one stage in Swin Transformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + depth (int): The number of blocks in this stage. + window size (int): The local window scale. Default: 7. + qkv_bias (int): enable bias for qkv if True. Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop_rate (float, optional): Dropout rate. Default: 0. + attn_drop_rate (float, optional): Attention dropout rate. Default: 0. + drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2. + downsample (BaseModule | None, optional): The downsample operation + module. Default: None. + act_cfg (dict, optional): The config dict of activation function. + Default: dict(type='GELU'). + norm_cfg (dict, optional): The config dict of nomalization. + Default: dict(type='LN'). + init_cfg (dict | list | None, optional): The init config. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + depth, + window_size=7, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + downsample=None, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg=None): + super().__init__() + + self.init_cfg = init_cfg + + drop_path_rate = drop_path_rate if isinstance( + drop_path_rate, + list) else [deepcopy(drop_path_rate) for _ in range(depth)] + + self.blocks = ModuleList() + for i in range(depth): + block = SwinBlock( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=feedforward_channels, + window_size=window_size, + shift=False if i % 2 == 0 else True, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate[i], + act_cfg=act_cfg, + norm_cfg=norm_cfg, + init_cfg=None) + self.blocks.append(block) + + self.downsample = downsample + + def forward(self, x, hw_shape): + for block in self.blocks: + x = block(x, hw_shape) + + if self.downsample: + x_down, down_hw_shape = self.downsample(x, hw_shape) + return x_down, down_hw_shape, x, hw_shape + else: + return x, hw_shape, x, hw_shape + + +@BACKBONES.register_module() +class SwinTransformer(BaseModule): + """ Swin Transformer + A PyTorch implement of : `Swin Transformer: + Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/abs/2103.14030 + + Inspiration from + https://github.com/microsoft/Swin-Transformer + + Args: + pretrain_img_size (int | tuple[int]): The size of input image when + pretrain. Defaults: 224. + in_channels (int): The num of input channels. + Defaults: 3. + embed_dims (int): The feature dimension. Default: 96. + patch_size (int | tuple[int]): Patch size. Default: 4. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + Default: 4. + depths (tuple[int]): Depths of each Swin Transformer stage. + Default: (2, 2, 6, 2). + num_heads (tuple[int]): Parallel attention heads of each Swin + Transformer stage. Default: (3, 6, 12, 24). + strides (tuple[int]): The patch merging or patch embedding stride of + each Swin Transformer stage. (In swin, we set kernel size equal to + stride.) Default: (4, 2, 2, 2). + out_indices (tuple[int]): Output from which stages. + Default: (0, 1, 2, 3). + qkv_bias (bool, optional): If True, add a learnable bias to query, key, + value. Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + patch_norm (bool): If add a norm layer for patch embed and patch + merging. Default: True. + drop_rate (float): Dropout rate. Defaults: 0. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Defaults: 0.1. + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults: False. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LN'). + norm_cfg (dict): Config dict for normalization layer at + output of backone. Defaults: dict(type='LN'). + pretrain_style (str): Choose to use official or mmcls pretrain weights. + Default: official. + pretrained (str, optional): model pretrained path. Default: None. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + pretrain_img_size=224, + in_channels=3, + embed_dims=96, + patch_size=4, + window_size=7, + mlp_ratio=4, + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + strides=(4, 2, 2, 2), + out_indices=(0, 1, 2, 3), + qkv_bias=True, + qk_scale=None, + patch_norm=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + use_abs_pos_embed=False, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + pretrain_style='official', + pretrained=None, + init_cfg=None): + super(SwinTransformer, self).__init__() + + if isinstance(pretrain_img_size, int): + pretrain_img_size = to_2tuple(pretrain_img_size) + elif isinstance(pretrain_img_size, tuple): + if len(pretrain_img_size) == 1: + pretrain_img_size = to_2tuple(pretrain_img_size[0]) + assert len(pretrain_img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(pretrain_img_size)}' + + assert pretrain_style in ['official', 'mmcls'], 'We only support load ' + 'official ckpt and mmcls ckpt.' + + if isinstance(pretrained, str) or pretrained is None: + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + else: + raise TypeError('pretrained must be a str or None') + + num_layers = len(depths) + self.out_indices = out_indices + self.use_abs_pos_embed = use_abs_pos_embed + self.pretrain_style = pretrain_style + self.pretrained = pretrained + self.init_cfg = init_cfg + + assert strides[0] == patch_size, 'Use non-overlapping patch embed.' + + self.patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=strides[0], + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None) + + if self.use_abs_pos_embed: + patch_row = pretrain_img_size[0] // patch_size + patch_col = pretrain_img_size[1] // patch_size + num_patches = patch_row * patch_col + self.absolute_pos_embed = nn.Parameter( + torch.zeros((1, num_patches, embed_dims))) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + # stochastic depth + total_depth = sum(depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + + self.stages = ModuleList() + in_channels = embed_dims + for i in range(num_layers): + if i < num_layers - 1: + downsample = PatchMerging( + in_channels=in_channels, + out_channels=2 * in_channels, + stride=strides[i + 1], + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None) + else: + downsample = None + + stage = SwinBlockSequence( + embed_dims=in_channels, + num_heads=num_heads[i], + feedforward_channels=mlp_ratio * in_channels, + depth=depths[i], + window_size=window_size, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[:depths[i]], + downsample=downsample, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + init_cfg=None) + self.stages.append(stage) + + dpr = dpr[depths[i]:] + if downsample: + in_channels = downsample.out_channels + + self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)] + # Add a norm layer for each output + for i in out_indices: + layer = build_norm_layer(norm_cfg, self.num_features[i])[1] + layer_name = f'norm{i}' + self.add_module(layer_name, layer) + + def init_weights(self): + if self.pretrained is None: + super().init_weights() + if self.use_abs_pos_embed: + trunc_normal_init(self.absolute_pos_embed, std=0.02) + for m in self.modules(): + if isinstance(m, Linear): + trunc_normal_init(m.weight, std=.02) + if m.bias is not None: + constant_init(m.bias, 0) + elif isinstance(m, LayerNorm): + constant_init(m.bias, 0) + constant_init(m.weight, 1.0) + elif isinstance(self.pretrained, str): + logger = get_root_logger() + ckpt = _load_checkpoint( + self.pretrained, logger=logger, map_location='cpu') + if 'state_dict' in ckpt: + state_dict = ckpt['state_dict'] + elif 'model' in ckpt: + state_dict = ckpt['model'] + else: + state_dict = ckpt + + if self.pretrain_style == 'official': + state_dict = swin_convert(state_dict) + + # strip prefix of state_dict + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + # reshape absolute position embedding + if state_dict.get('absolute_pos_embed') is not None: + absolute_pos_embed = state_dict['absolute_pos_embed'] + N1, L, C1 = absolute_pos_embed.size() + N2, C2, H, W = self.absolute_pos_embed.size() + if N1 != N2 or C1 != C2 or L != H * W: + logger.warning('Error in loading absolute_pos_embed, pass') + else: + state_dict['absolute_pos_embed'] = absolute_pos_embed.view( + N2, H, W, C2).permute(0, 3, 1, 2).contiguous() + + # interpolate position bias table if needed + relative_position_bias_table_keys = [ + k for k in state_dict.keys() + if 'relative_position_bias_table' in k + ] + for table_key in relative_position_bias_table_keys: + table_pretrained = state_dict[table_key] + table_current = self.state_dict()[table_key] + L1, nH1 = table_pretrained.size() + L2, nH2 = table_current.size() + if nH1 != nH2: + logger.warning(f'Error in loading {table_key}, pass') + else: + if L1 != L2: + S1 = int(L1**0.5) + S2 = int(L2**0.5) + table_pretrained_resized = F.interpolate( + table_pretrained.permute(1, 0).reshape( + 1, nH1, S1, S1), + size=(S2, S2), + mode='bicubic') + state_dict[table_key] = table_pretrained_resized.view( + nH2, L2).permute(1, 0).contiguous() + + # load state_dict + self.load_state_dict(state_dict, False) + + def forward(self, x): + x = self.patch_embed(x) + + hw_shape = (self.patch_embed.DH, self.patch_embed.DW) + if self.use_abs_pos_embed: + x = x + self.absolute_pos_embed + x = self.drop_after_pos(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape, out, out_hw_shape = stage(x, hw_shape) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(out) + out = out.view(-1, *out_hw_shape, + self.num_features[i]).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + + return outs diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py index 440b46310..1ad20a1ca 100644 --- a/mmseg/models/backbones/vit.py +++ b/mmseg/models/backbones/vit.py @@ -4,8 +4,8 @@ import warnings import torch import torch.nn as nn import torch.nn.functional as F -from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init, - kaiming_init, normal_init, trunc_normal_init) +from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init, + normal_init, trunc_normal_init) from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention from mmcv.runner import BaseModule, ModuleList, _load_checkpoint from torch.nn.modules.batchnorm import _BatchNorm @@ -13,7 +13,7 @@ from torch.nn.modules.utils import _pair as to_2tuple from mmseg.utils import get_root_logger from ..builder import BACKBONES -from ..utils import vit_convert +from ..utils import PatchEmbed, vit_convert class TransformerEncoderLayer(BaseModule): @@ -93,49 +93,6 @@ class TransformerEncoderLayer(BaseModule): return x -# Modified from pytorch-image-models -class PatchEmbed(BaseModule): - """Image to Patch Embedding. - - Args: - patch_size (int): The size of one patch - in_channels (int): The num of input channels. - embed_dims (int): The dimensions of embedding. - norm_cfg (dict, optional): Config dict for normalization layer. - conv_cfg (dict, optional): The config dict for conv layers. - Default: None. - """ - - def __init__(self, - patch_size=16, - in_channels=3, - embed_dims=768, - norm_cfg=None, - conv_cfg=None): - super(PatchEmbed, self).__init__() - - # Use conv layer to embed - self.projection = build_conv_layer( - conv_cfg, - in_channels, - embed_dims, - kernel_size=patch_size, - stride=patch_size) - - if norm_cfg is not None: - self.norm = build_norm_layer(norm_cfg, embed_dims)[1] - else: - self.norm = None - - def forward(self, x): - x = self.projection(x).flatten(2).transpose(1, 2) - - if self.norm is not None: - x = self.norm(x) - - return x - - @BACKBONES.register_module() class VisionTransformer(BaseModule): """Vision Transformer. @@ -248,10 +205,14 @@ class VisionTransformer(BaseModule): self.init_cfg = init_cfg self.patch_embed = PatchEmbed( - patch_size=patch_size, in_channels=in_channels, embed_dims=embed_dims, - norm_cfg=norm_cfg if patch_norm else None) + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None, + ) num_patches = (img_size[0] // patch_size) * \ (img_size[1] // patch_size) diff --git a/mmseg/models/utils/__init__.py b/mmseg/models/utils/__init__.py index b7066eb03..277dd2676 100644 --- a/mmseg/models/utils/__init__.py +++ b/mmseg/models/utils/__init__.py @@ -1,12 +1,14 @@ +from .ckpt_convert import swin_convert, vit_convert +from .embed import PatchEmbed from .inverted_residual import InvertedResidual, InvertedResidualV3 from .make_divisible import make_divisible from .res_layer import ResLayer from .se_layer import SELayer from .self_attention_block import SelfAttentionBlock -from .timm_convert import vit_convert from .up_conv_block import UpConvBlock __all__ = [ 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', - 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'vit_convert' + 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'vit_convert', + 'swin_convert', 'PatchEmbed' ] diff --git a/mmseg/models/utils/ckpt_convert.py b/mmseg/models/utils/ckpt_convert.py new file mode 100644 index 000000000..0b1b27707 --- /dev/null +++ b/mmseg/models/utils/ckpt_convert.py @@ -0,0 +1,90 @@ +from collections import OrderedDict + + +def swin_convert(ckpt): + new_ckpt = OrderedDict() + + def correct_unfold_reduction_order(x): + out_channel, in_channel = x.shape + x = x.reshape(out_channel, 4, in_channel // 4) + x = x[:, [0, 2, 1, 3], :].transpose(1, + 2).reshape(out_channel, in_channel) + return x + + def correct_unfold_norm_order(x): + in_channel = x.shape[0] + x = x.reshape(4, in_channel // 4) + x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) + return x + + for k, v in ckpt.items(): + if k.startswith('head'): + continue + elif k.startswith('layers'): + new_v = v + if 'attn.' in k: + new_k = k.replace('attn.', 'attn.w_msa.') + elif 'mlp.' in k: + if 'mlp.fc1.' in k: + new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.') + elif 'mlp.fc2.' in k: + new_k = k.replace('mlp.fc2.', 'ffn.layers.1.') + else: + new_k = k.replace('mlp.', 'ffn.') + elif 'downsample' in k: + new_k = k + if 'reduction.' in k: + new_v = correct_unfold_reduction_order(v) + elif 'norm.' in k: + new_v = correct_unfold_norm_order(v) + else: + new_k = k + new_k = new_k.replace('layers', 'stages', 1) + elif k.startswith('patch_embed'): + new_v = v + if 'proj' in k: + new_k = k.replace('proj', 'projection') + else: + new_k = k + else: + new_v = v + new_k = k + + new_ckpt[new_k] = new_v + + return new_ckpt + + +def vit_convert(ckpt): + + new_ckpt = OrderedDict() + + for k, v in ckpt.items(): + if k.startswith('head'): + continue + if k.startswith('norm'): + new_k = k.replace('norm.', 'ln1.') + elif k.startswith('patch_embed'): + if 'proj' in k: + new_k = k.replace('proj', 'projection') + else: + new_k = k + elif k.startswith('blocks'): + if 'norm' in k: + new_k = k.replace('norm', 'ln') + elif 'mlp.fc1' in k: + new_k = k.replace('mlp.fc1', 'ffn.layers.0.0') + elif 'mlp.fc2' in k: + new_k = k.replace('mlp.fc2', 'ffn.layers.1') + elif 'attn.qkv' in k: + new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_') + elif 'attn.proj' in k: + new_k = k.replace('attn.proj', 'attn.attn.out_proj') + else: + new_k = k + new_k = new_k.replace('blocks.', 'layers.') + else: + new_k = k + new_ckpt[new_k] = v + + return new_ckpt diff --git a/mmseg/models/utils/embed.py b/mmseg/models/utils/embed.py new file mode 100644 index 000000000..3bbb45b37 --- /dev/null +++ b/mmseg/models/utils/embed.py @@ -0,0 +1,89 @@ +import torch.nn.functional as F +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmcv.runner.base_module import BaseModule +from torch.nn.modules.utils import _pair as to_2tuple + + +# Modified from Pytorch-Image-Models +class PatchEmbed(BaseModule): + """Image to Patch Embedding V2. + + We use a conv layer to implement PatchEmbed. + Args: + in_channels (int): The num of input channels. Default: 3 + embed_dims (int): The dimensions of embedding. Default: 768 + conv_type (dict, optional): The config dict for conv layers type + selection. Default: None. + kernel_size (int): The kernel_size of embedding conv. Default: 16. + stride (int): The slide stride of embedding conv. + Default: None (Default to be equal with kernel_size). + padding (int): The padding length of embedding conv. Default: 0. + dilation (int): The dilation rate of embedding conv. Default: 1. + norm_cfg (dict, optional): Config dict for normalization layer. + init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, + in_channels=3, + embed_dims=768, + conv_type=None, + kernel_size=16, + stride=16, + padding=0, + dilation=1, + norm_cfg=None, + init_cfg=None): + super(PatchEmbed, self).__init__() + + self.embed_dims = embed_dims + self.init_cfg = init_cfg + + if stride is None: + stride = kernel_size + + # The default setting of patch size is eaual to kernel size. + patch_size = kernel_size + if isinstance(patch_size, int): + patch_size = to_2tuple(patch_size) + elif isinstance(patch_size, tuple): + if len(patch_size) == 1: + patch_size = to_2tuple(patch_size[0]) + assert len(patch_size) == 2, \ + f'The size of patch should have length 1 or 2, ' \ + f'but got {len(patch_size)}' + + self.patch_size = patch_size + + # Use conv layer to embed + conv_type = conv_type or dict(type='Conv2d') + self.projection = build_conv_layer( + dict(type=conv_type), + in_channels=in_channels, + out_channels=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation) + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + else: + self.norm = None + + def forward(self, x): + H, W = x.shape[2], x.shape[3] + if H % self.patch_size[0] != 0: + x = F.pad(x, + (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + if W % self.patch_size[1] != 0: + x = F.pad(x, + (0, self.patch_size[1] - W % self.patch_size[1], 0, 0)) + x = self.projection(x) + self.DH, self.DW = x.shape[2], x.shape[3] + x = x.flatten(2).transpose(1, 2) + + if self.norm is not None: + x = self.norm(x) + + return x diff --git a/mmseg/models/utils/timm_convert.py b/mmseg/models/utils/timm_convert.py deleted file mode 100644 index 2ce48b06d..000000000 --- a/mmseg/models/utils/timm_convert.py +++ /dev/null @@ -1,32 +0,0 @@ -from collections import OrderedDict - - -def vit_convert(timm_dict): - - mmseg_dict = OrderedDict() - - for k, v in timm_dict.items(): - if k.startswith('head'): - continue - if k.startswith('norm'): - new_k = k.replace('norm.', 'ln1.') - elif k.startswith('patch_embed'): - if 'proj' in k: - new_k = k.replace('proj', 'projection') - elif k.startswith('blocks'): - new_k = k.replace('blocks.', 'layers.') - if 'norm' in new_k: - new_k = new_k.replace('norm', 'ln') - elif 'mlp.fc1' in new_k: - new_k = new_k.replace('mlp.fc1', 'ffn.layers.0.0') - elif 'mlp.fc2' in new_k: - new_k = new_k.replace('mlp.fc2', 'ffn.layers.1') - elif 'attn.qkv' in new_k: - new_k = new_k.replace('attn.qkv.', 'attn.attn.in_proj_') - elif 'attn.proj' in new_k: - new_k = new_k.replace('attn.proj', 'attn.attn.out_proj') - else: - new_k = k - mmseg_dict[new_k] = v - - return mmseg_dict diff --git a/tests/test_models/test_backbones/test_swin.py b/tests/test_models/test_backbones/test_swin.py new file mode 100644 index 000000000..42e308667 --- /dev/null +++ b/tests/test_models/test_backbones/test_swin.py @@ -0,0 +1,64 @@ +import pytest +import torch + +from mmseg.models.backbones import SwinTransformer + + +def test_swin_transformer(): + """Test Swin Transformer backbone.""" + + with pytest.raises(AssertionError): + # We only support 'official' or 'mmcls' for this arg. + model = SwinTransformer(pretrain_style='swin') + + with pytest.raises(TypeError): + # Pretrained arg must be str or None. + model = SwinTransformer(pretrained=123) + + with pytest.raises(AssertionError): + # Because swin use non-overlapping patch embed, so the stride of patch + # embed must be equal to patch size. + model = SwinTransformer(strides=(2, 2, 2, 2), patch_size=4) + + # Test absolute position embedding + temp = torch.randn((1, 3, 224, 224)) + model = SwinTransformer(pretrain_img_size=224, use_abs_pos_embed=True) + model.init_weights() + model(temp) + + # Test patch norm + model = SwinTransformer(patch_norm=False) + model(temp) + + # Test pretrain img size + model = SwinTransformer(pretrain_img_size=(224, )) + + with pytest.raises(AssertionError): + model = SwinTransformer(pretrain_img_size=(224, 224, 224)) + + # Test normal inference + temp = torch.randn((1, 3, 512, 512)) + model = SwinTransformer() + outs = model(temp) + assert outs[0].shape == (1, 96, 128, 128) + assert outs[1].shape == (1, 192, 64, 64) + assert outs[2].shape == (1, 384, 32, 32) + assert outs[3].shape == (1, 768, 16, 16) + + # Test abnormal inference + temp = torch.randn((1, 3, 511, 511)) + model = SwinTransformer() + outs = model(temp) + assert outs[0].shape == (1, 96, 128, 128) + assert outs[1].shape == (1, 192, 64, 64) + assert outs[2].shape == (1, 384, 32, 32) + assert outs[3].shape == (1, 768, 16, 16) + + # Test abnormal inference + temp = torch.randn((1, 3, 112, 137)) + model = SwinTransformer() + outs = model(temp) + assert outs[0].shape == (1, 96, 28, 35) + assert outs[1].shape == (1, 192, 14, 18) + assert outs[2].shape == (1, 384, 7, 9) + assert outs[3].shape == (1, 768, 4, 5) diff --git a/tests/test_models/test_backbones/test_vit.py b/tests/test_models/test_backbones/test_vit.py index 007781f2f..4577b97b8 100644 --- a/tests/test_models/test_backbones/test_vit.py +++ b/tests/test_models/test_backbones/test_vit.py @@ -24,7 +24,7 @@ def test_vit_backbone(): x = torch.randn(1, 196) VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear') - with pytest.raises(RuntimeError): + with pytest.raises(IndexError): # forward inputs must be [N, C, H, W] x = torch.randn(3, 30, 30) model = VisionTransformer()