# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Sequence

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList

from ..builder import BACKBONES
from ..utils import PatchEmbed, PatchMerging, ShiftWindowMSA
from .base_backbone import BaseBackbone


class SwinBlock(BaseModule):
    """Swin Transformer block.

    Args:
        embed_dims (int): Number of input channels.
        input_resolution (Tuple[int, int]): The resolution of the input
            feature map.
        num_heads (int): Number of attention heads.
        window_size (int, optional): The height and width of the window.
            Defaults to 7.
        shift (bool, optional): Shift the attention window or not.
            Defaults to False.
        ffn_ratio (float, optional): The expansion ratio of feedforward
            network hidden layer channels. Defaults to 4.
        drop_path (float, optional): The drop path rate after attention and
            ffn. Defaults to 0.
        attn_cfgs (dict, optional): The extra config of Shift Window-MSA.
            Defaults to empty dict.
        ffn_cfgs (dict, optional): The extra config of FFN.
            Defaults to empty dict.
        norm_cfg (dict, optional): The config of norm layers.
            Defaults to dict(type='LN').
        auto_pad (bool, optional): Auto pad the feature map to be divisible by
            window_size. Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 input_resolution,
                 num_heads,
                 window_size=7,
                 shift=False,
                 ffn_ratio=4.,
                 drop_path=0.,
                 attn_cfgs=dict(),
                 ffn_cfgs=dict(),
                 norm_cfg=dict(type='LN'),
                 auto_pad=False,
                 init_cfg=None):

        super(SwinBlock, self).__init__(init_cfg)

        _attn_cfgs = {
            'embed_dims': embed_dims,
            'input_resolution': input_resolution,
            'num_heads': num_heads,
            'shift_size': window_size // 2 if shift else 0,
            'window_size': window_size,
            'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
            'auto_pad': auto_pad,
            **attn_cfgs
        }
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = ShiftWindowMSA(**_attn_cfgs)

        _ffn_cfgs = {
            'embed_dims': embed_dims,
            'feedforward_channels': int(embed_dims * ffn_ratio),
            'num_fcs': 2,
            'ffn_drop': 0,
            'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
            'act_cfg': dict(type='GELU'),
            **ffn_cfgs
        }
        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = FFN(**_ffn_cfgs)

    def forward(self, x):
        identity = x
        x = self.norm1(x)
        x = self.attn(x)
        x = x + identity

        identity = x
        x = self.norm2(x)
        x = self.ffn(x, identity=identity)
        return x


class SwinBlockSequence(BaseModule):
    """Module with successive Swin Transformer blocks and downsample layer.

    Args:
        embed_dims (int): Number of input channels.
        input_resolution (Tuple[int, int]): The resolution of the input
            feature map.
        depth (int): Number of successive swin transformer blocks.
        num_heads (int): Number of attention heads.
        downsample (bool, optional): Downsample the output of blocks by patch
            merging. Defaults to False.
        downsample_cfg (dict, optional): The extra config of the patch merging
            layer. Defaults to empty dict.
        drop_paths (Sequence[float] | float, optional): The drop path rate in
            each block. Defaults to 0.
        block_cfgs (Sequence[dict] | dict, optional): The extra config of each
            block. Defaults to empty dicts.
        auto_pad (bool, optional): Auto pad the feature map to be divisible by
            window_size. Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
""" def __init__(self, embed_dims, input_resolution, depth, num_heads, downsample=False, downsample_cfg=dict(), drop_paths=0., block_cfgs=dict(), auto_pad=False, init_cfg=None): super().__init__(init_cfg) if not isinstance(drop_paths, Sequence): drop_paths = [drop_paths] * depth if not isinstance(block_cfgs, Sequence): block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)] self.embed_dims = embed_dims self.input_resolution = input_resolution self.blocks = ModuleList() for i in range(depth): _block_cfg = { 'embed_dims': embed_dims, 'input_resolution': input_resolution, 'num_heads': num_heads, 'shift': False if i % 2 == 0 else True, 'drop_path': drop_paths[i], 'auto_pad': auto_pad, **block_cfgs[i] } block = SwinBlock(**_block_cfg) self.blocks.append(block) if downsample: _downsample_cfg = { 'input_resolution': input_resolution, 'in_channels': embed_dims, 'expansion_ratio': 2, 'norm_cfg': dict(type='LN'), **downsample_cfg } self.downsample = PatchMerging(**_downsample_cfg) else: self.downsample = None def forward(self, x): for block in self.blocks: x = block(x) if self.downsample: x = self.downsample(x) return x @property def out_resolution(self): if self.downsample: return self.downsample.output_resolution else: return self.input_resolution @property def out_channels(self): if self.downsample: return self.downsample.out_channels else: return self.embed_dims @BACKBONES.register_module() class SwinTransformer(BaseBackbone): """ Swin Transformer A PyTorch implement of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_ Inspiration from https://github.com/microsoft/Swin-Transformer Args: arch (str | dict): Swin Transformer architecture Defaults to 'T'. img_size (int | tuple): The size of input image. Defaults to 224. in_channels (int): The num of input channels. Defaults to 3. drop_rate (float): Dropout rate after embedding. Defaults to 0. drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. use_abs_pos_embed (bool): If True, add absolute position embedding to the patch embedding. Defaults to False. auto_pad (bool): If True, auto pad feature map to fit window_size. Defaults to False. norm_cfg (dict, optional): Config dict for normalization layer at end of backone. Defaults to dict(type='LN') stage_cfgs (Sequence | dict, optional): Extra config dict for each stage. Defaults to empty dict. patch_cfg (dict, optional): Extra config dict for patch embedding. Defaults to empty dict. init_cfg (dict, optional): The Config for initialization. Defaults to None. 
    Examples:
        >>> from mmcls.models import SwinTransformer
        >>> import torch
        >>> extra_config = dict(
        >>>     arch='tiny',
        >>>     stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
        >>>                                     'expansion_ratio': 3}),
        >>>     auto_pad=True)
        >>> self = SwinTransformer(**extra_config)
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> output = self.forward(inputs)
        >>> print(output.shape)
        (1, 2592, 4)
    """
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'],
                        {'embed_dims': 96,
                         'depths': [2, 2, 6, 2],
                         'num_heads': [3, 6, 12, 24]}),
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': 96,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [3, 6, 12, 24]}),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': 128,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [4, 8, 16, 32]}),
        **dict.fromkeys(['l', 'large'],
                        {'embed_dims': 192,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [6, 12, 24, 48]}),
    }  # yapf: disable

    _version = 2

    def __init__(self,
                 arch='T',
                 img_size=224,
                 in_channels=3,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 out_indices=(3, ),
                 use_abs_pos_embed=False,
                 auto_pad=False,
                 norm_cfg=dict(type='LN'),
                 stage_cfgs=dict(),
                 patch_cfg=dict(),
                 init_cfg=None):
        super(SwinTransformer, self).__init__(init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {'embed_dims', 'depths', 'num_heads'}
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.num_heads = self.arch_settings['num_heads']
        self.num_layers = len(self.depths)
        self.out_indices = out_indices
        self.use_abs_pos_embed = use_abs_pos_embed
        self.auto_pad = auto_pad

        _patch_cfg = {
            'img_size': img_size,
            'in_channels': in_channels,
            'embed_dims': self.embed_dims,
            'conv_cfg': dict(type='Conv2d', kernel_size=4, stride=4),
            'norm_cfg': dict(type='LN'),
            **patch_cfg
        }
        self.patch_embed = PatchEmbed(**_patch_cfg)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        if self.use_abs_pos_embed:
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, self.embed_dims))

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        # stochastic depth
        total_depth = sum(self.depths)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]  # stochastic depth decay rule

        self.stages = ModuleList()
        embed_dims = self.embed_dims
        input_resolution = patches_resolution
        for i, (depth,
                num_heads) in enumerate(zip(self.depths, self.num_heads)):
            if isinstance(stage_cfgs, Sequence):
                stage_cfg = stage_cfgs[i]
            else:
                stage_cfg = deepcopy(stage_cfgs)
            downsample = True if i < self.num_layers - 1 else False
            _stage_cfg = {
                'embed_dims': embed_dims,
                'depth': depth,
                'num_heads': num_heads,
                'downsample': downsample,
                'input_resolution': input_resolution,
                'drop_paths': dpr[:depth],
                'auto_pad': auto_pad,
                **stage_cfg
            }

            stage = SwinBlockSequence(**_stage_cfg)
            self.stages.append(stage)

            dpr = dpr[depth:]
            embed_dims = stage.out_channels
            input_resolution = stage.out_resolution

        for i in out_indices:
            if norm_cfg is not None:
                norm_layer = build_norm_layer(norm_cfg, embed_dims)[1]
            else:
                norm_layer = nn.Identity()

            self.add_module(f'norm{i}', norm_layer)

    def init_weights(self):
        super(SwinTransformer, self).init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if a pretrained model is used.
            return

        if self.use_abs_pos_embed:
            trunc_normal_(self.absolute_pos_embed, std=0.02)

    def forward(self, x):
        x = self.patch_embed(x)
        if self.use_abs_pos_embed:
            x = x + self.absolute_pos_embed
        x = self.drop_after_pos(x)

        outs = []
        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                out = norm_layer(x)
                out = out.view(-1, *stage.out_resolution,
                               stage.out_channels).permute(0, 3, 1,
                                                           2).contiguous()
                outs.append(out)

        return tuple(outs)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, *args,
                              **kwargs):
        """load checkpoints."""
        # The names of some parameters in old checkpoints have been changed.
        version = local_metadata.get('version', None)
        if (version is None
                or version < 2) and self.__class__ is SwinTransformer:
            final_stage_num = len(self.stages) - 1
            state_dict_keys = list(state_dict.keys())
            for k in state_dict_keys:
                if k.startswith('norm.') or k.startswith('backbone.norm.'):
                    convert_key = k.replace('norm.', f'norm{final_stage_num}.')
                    state_dict[convert_key] = state_dict[k]
                    del state_dict[k]

        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      *args, **kwargs)
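
# A minimal usage sketch (not part of the implementation above). It assumes
# this file ships inside mmcls so that ``from mmcls.models import
# SwinTransformer`` resolves to the class defined here. The expected shape
# follows from the 'tiny' defaults: 4x4 patches on a 224x224 input give a
# 56x56 grid, each of the three PatchMerging layers halves it down to 7x7
# while channels grow 96 -> 192 -> 384 -> 768, and ``out_indices=(3, )``
# returns only the last stage's feature map in NCHW layout.
#
#     >>> import torch
#     >>> from mmcls.models import SwinTransformer
#     >>> model = SwinTransformer(arch='tiny')
#     >>> model.init_weights()
#     >>> feats = model(torch.rand(1, 3, 224, 224))
#     >>> tuple(feats[-1].shape)
#     (1, 768, 7, 7)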