From c708770b427df933e63aff6be0cfe571a18e868a Mon Sep 17 00:00:00 2001
From: Ma Zerun
Date: Thu, 3 Mar 2022 13:10:12 +0800
Subject: [PATCH] [Enhance] Support dynamic input shape for ViT-based
 algorithms. (#706)

* Move `resize_pos_embed` to `mmcls.models.utils`
* Refactor Vision Transformer
* Refactor DeiT
* Refactor MLP-Mixer
* Refactor Swin-Transformer
* Remove `indexing` arg
* Support dynamic inputs for t2t_vit
* Add copyright
* Fix bugs in swin transformer
* Add `pad_small_maps` option
* Update swin transformer
* Handle `attn_mask` in checkpoints of swin
* Improve by comments
---
 mmcls/models/backbones/deit.py                |  60 ++-
 mmcls/models/backbones/mlp_mixer.py           |  39 +-
 mmcls/models/backbones/swin_transformer.py    | 256 ++++++-----
 mmcls/models/backbones/t2t_vit.py             | 137 ++++--
 mmcls/models/backbones/vision_transformer.py  | 140 +++---
 mmcls/models/utils/__init__.py                |   4 +-
 mmcls/models/utils/attention.py               | 249 ++++++-----
 mmcls/models/utils/embed.py                   |  54 +++
 tests/test_downstream/test_mmdet_inference.py |   3 +-
 tests/test_models/test_backbones/__init__.py  |   1 +
 tests/test_models/test_backbones/test_deit.py | 148 +++++--
 .../test_backbones/test_mlp_mixer.py          | 124 ++++--
 .../test_backbones/test_swin_transformer.py   | 399 +++++++++---------
 .../test_backbones/test_t2t_vit.py            | 207 ++++++---
 .../test_backbones/test_vision_transformer.py | 281 ++++++------
 tests/test_models/test_backbones/utils.py     |  31 ++
 .../test_models/test_utils/test_attention.py  | 297 +++++++------
 17 files changed, 1483 insertions(+), 947 deletions(-)
 create mode 100644 tests/test_models/test_backbones/__init__.py
 create mode 100644 tests/test_models/test_backbones/utils.py

diff --git a/mmcls/models/backbones/deit.py b/mmcls/models/backbones/deit.py
index 37851798..56e74e07 100644
--- a/mmcls/models/backbones/deit.py
+++ b/mmcls/models/backbones/deit.py
@@ -15,21 +15,38 @@ class DistilledVisionTransformer(VisionTransformer):
     distillation through attention `_

     Args:
-        arch (str | dict): Vision Transformer architecture
-            Default: 'b'
-        img_size (int | tuple): Input image size
-        patch_size (int | tuple): The patch size
+        arch (str | dict): Vision Transformer architecture. If use string,
+            choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
+            and 'deit-base'. If use dict, it should have below keys:
+
+            - **embed_dims** (int): The dimensions of embedding.
+            - **num_layers** (int): The number of transformer encoder layers.
+            - **num_heads** (int): The number of heads in attention modules.
+            - **feedforward_channels** (int): The hidden dimensions in
+              feedforward modules.
+
+            Defaults to 'deit-base'.
+        img_size (int | tuple): The expected input image shape. Because we
+            support dynamic input shape, just set the argument to the most
+            common input image shape. Defaults to 224.
+        patch_size (int | tuple): The patch size in patch embedding.
+            Defaults to 16.
+        in_channels (int): The num of input channels. Defaults to 3.
         out_indices (Sequence | int): Output from which stages.
             Defaults to -1, means the last stage.
         drop_rate (float): Probability of an element to be zeroed.
             Defaults to 0.
         drop_path_rate (float): stochastic depth rate. Defaults to 0.
+        qkv_bias (bool): Whether to add bias for qkv in attention modules.
+            Defaults to True.
         norm_cfg (dict): Config dict for normalization layer.
             Defaults to ``dict(type='LN')``.
         final_norm (bool): Whether to add a additional layer to normalize
            final feature map. Defaults to True.
+        with_cls_token (bool): Whether concatenating class token into image
+            tokens as transformer input.
Defaults to True. output_cls_token (bool): Whether output the cls_token. If set True, - `with_cls_token` must be True. Defaults to True. + ``with_cls_token`` must be True. Defaults to True. interpolate_mode (str): Select the interpolate mode for position embeding vector resize. Defaults to "bicubic". patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. @@ -40,22 +57,31 @@ class DistilledVisionTransformer(VisionTransformer): """ num_extra_tokens = 2 # cls_token, dist_token - def __init__(self, *args, **kwargs): - super(DistilledVisionTransformer, self).__init__(*args, **kwargs) + def __init__(self, arch='deit-base', *args, **kwargs): + super(DistilledVisionTransformer, self).__init__( + arch=arch, *args, **kwargs) self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) def forward(self, x): B = x.shape[0] - x = self.patch_embed(x) - patch_resolution = self.patch_embed.patches_resolution + x, patch_resolution = self.patch_embed(x) # stole cls_tokens impl from Phil Wang, thanks cls_tokens = self.cls_token.expand(B, -1, -1) dist_token = self.dist_token.expand(B, -1, -1) x = torch.cat((cls_tokens, dist_token, x), dim=1) - x = x + self.pos_embed + x = x + self.resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) x = self.drop_after_pos(x) + if not self.with_cls_token: + # Remove class token for transformer encoder input + x = x[:, 2:] + outs = [] for i, layer in enumerate(self.layers): x = layer(x) @@ -65,10 +91,16 @@ class DistilledVisionTransformer(VisionTransformer): if i in self.out_indices: B, _, C = x.shape - patch_token = x[:, 2:].reshape(B, *patch_resolution, C) - patch_token = patch_token.permute(0, 3, 1, 2) - cls_token = x[:, 0] - dist_token = x[:, 1] + if self.with_cls_token: + patch_token = x[:, 2:].reshape(B, *patch_resolution, C) + patch_token = patch_token.permute(0, 3, 1, 2) + cls_token = x[:, 0] + dist_token = x[:, 1] + else: + patch_token = x.reshape(B, *patch_resolution, C) + patch_token = patch_token.permute(0, 3, 1, 2) + cls_token = None + dist_token = None if self.output_cls_token: out = [patch_token, cls_token, dist_token] else: diff --git a/mmcls/models/backbones/mlp_mixer.py b/mmcls/models/backbones/mlp_mixer.py index 1e4a51f2..13171a4b 100644 --- a/mmcls/models/backbones/mlp_mixer.py +++ b/mmcls/models/backbones/mlp_mixer.py @@ -3,11 +3,11 @@ from typing import Sequence import torch.nn as nn from mmcv.cnn import build_norm_layer -from mmcv.cnn.bricks.transformer import FFN +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed from mmcv.runner.base_module import BaseModule, ModuleList from ..builder import BACKBONES -from ..utils import PatchEmbed, to_2tuple +from ..utils import to_2tuple from .base_backbone import BaseBackbone @@ -105,10 +105,20 @@ class MlpMixer(BaseBackbone): `_ Args: - arch (str | dict): MLP Mixer architecture - Defaults to 'b'. - img_size (int | tuple): Input image size. - patch_size (int | tuple): The patch size. + arch (str | dict): MLP Mixer architecture. If use string, choose from + 'small', 'base' and 'large'. If use dict, it should have below + keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of MLP blocks. + - **tokens_mlp_dims** (int): The hidden dimensions for tokens FFNs. + - **channels_mlp_dims** (int): The The hidden dimensions for + channels FFNs. + + Defaults to 'base'. + img_size (int | tuple): The input image shape. Defaults to 224. 
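Before moving on, a minimal usage sketch (not part of the patch) of what the DeiT refactor above buys: the position embedding is interpolated to match whatever patch grid the input actually produces, so resolutions other than the configured `img_size` work. The shapes follow the `output_cls_token=True` branch of the new forward; 288 is an arbitrary non-default resolution.

```python
import torch
from mmcls.models.backbones import DistilledVisionTransformer

model = DistilledVisionTransformer(arch='deit-base', img_size=224)
model.eval()
# 288 != img_size: resize_pos_embed interpolates the position embedding
# to the 18x18 patch grid produced by the patch embedding.
patch_token, cls_token, dist_token = model(torch.randn(1, 3, 288, 288))[-1]
assert patch_token.shape == (1, 768, 18, 18)  # 288 // patch_size(16) == 18
assert cls_token.shape == (1, 768) and dist_token.shape == (1, 768)
```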
+ patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. out_indices (Sequence | int): Output from which layer. Defaults to -1, means the last layer. drop_rate (float): Probability of an element to be zeroed. @@ -149,7 +159,7 @@ class MlpMixer(BaseBackbone): } def __init__(self, - arch='b', + arch='base', img_size=224, patch_size=16, out_indices=-1, @@ -184,14 +194,16 @@ class MlpMixer(BaseBackbone): self.img_size = to_2tuple(img_size) _patch_cfg = dict( - img_size=img_size, + input_size=img_size, embed_dims=self.embed_dims, - conv_cfg=dict( - type='Conv2d', kernel_size=patch_size, stride=patch_size), + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, ) _patch_cfg.update(patch_cfg) self.patch_embed = PatchEmbed(**_patch_cfg) - num_patches = self.patch_embed.num_patches + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] if isinstance(out_indices, int): out_indices = [out_indices] @@ -232,7 +244,10 @@ class MlpMixer(BaseBackbone): return getattr(self, self.norm1_name) def forward(self, x): - x = self.patch_embed(x) + assert x.shape[2:] == self.img_size, \ + "The MLP-Mixer doesn't support dynamic input shape. " \ + f'Please input images with shape {self.img_size}' + x, _ = self.patch_embed(x) outs = [] for i, layer in enumerate(self.layers): diff --git a/mmcls/models/backbones/swin_transformer.py b/mmcls/models/backbones/swin_transformer.py index 04a8d14a..966f6acd 100644 --- a/mmcls/models/backbones/swin_transformer.py +++ b/mmcls/models/backbones/swin_transformer.py @@ -2,17 +2,18 @@ from copy import deepcopy from typing import Sequence +import numpy as np import torch import torch.nn as nn import torch.utils.checkpoint as cp from mmcv.cnn import build_norm_layer -from mmcv.cnn.bricks.transformer import FFN +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed, PatchMerging from mmcv.cnn.utils.weight_init import trunc_normal_ from mmcv.runner.base_module import BaseModule, ModuleList from mmcv.utils.parrots_wrapper import _BatchNorm from ..builder import BACKBONES -from ..utils import PatchEmbed, PatchMerging, ShiftWindowMSA +from ..utils import ShiftWindowMSA, resize_pos_embed, to_2tuple from .base_backbone import BaseBackbone @@ -21,45 +22,41 @@ class SwinBlock(BaseModule): Args: embed_dims (int): Number of input channels. - input_resolution (Tuple[int, int]): The resolution of the input feature - map. num_heads (int): Number of attention heads. - window_size (int, optional): The height and width of the window. - Defaults to 7. - shift (bool, optional): Shift the attention window or not. + window_size (int): The height and width of the window. Defaults to 7. + shift (bool): Shift the attention window or not. Defaults to False. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + drop_path (float): The drop path rate after attention and ffn. + Defaults to 0. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. Defaults to False. - ffn_ratio (float, optional): The expansion ratio of feedforward network - hidden layer channels. Defaults to 4. - drop_path (float, optional): The drop path rate after attention and - ffn. Defaults to 0. - attn_cfgs (dict, optional): The extra config of Shift Window-MSA. 
+ attn_cfgs (dict): The extra config of Shift Window-MSA. Defaults to empty dict. - ffn_cfgs (dict, optional): The extra config of FFN. - Defaults to empty dict. - norm_cfg (dict, optional): The config of norm layers. - Defaults to dict(type='LN'). - with_cp (bool, optional): Use checkpoint or not. Using checkpoint - will save some memory while slowing down the training speed. - Defaults to False. - auto_pad (bool, optional): Auto pad the feature map to be divisible by - window_size, Defaults to False. + ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict. + norm_cfg (dict): The config of norm layers. + Defaults to ``dict(type='LN')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. init_cfg (dict, optional): The extra config for initialization. - Default: None. + Defaults to None. """ def __init__(self, embed_dims, - input_resolution, num_heads, window_size=7, shift=False, ffn_ratio=4., drop_path=0., + pad_small_map=False, attn_cfgs=dict(), ffn_cfgs=dict(), norm_cfg=dict(type='LN'), with_cp=False, - auto_pad=False, init_cfg=None): super(SwinBlock, self).__init__(init_cfg) @@ -67,12 +64,11 @@ class SwinBlock(BaseModule): _attn_cfgs = { 'embed_dims': embed_dims, - 'input_resolution': input_resolution, 'num_heads': num_heads, 'shift_size': window_size // 2 if shift else 0, 'window_size': window_size, 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), - 'auto_pad': auto_pad, + 'pad_small_map': pad_small_map, **attn_cfgs } self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] @@ -90,12 +86,12 @@ class SwinBlock(BaseModule): self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] self.ffn = FFN(**_ffn_cfgs) - def forward(self, x): + def forward(self, x, hw_shape): def _inner_forward(x): identity = x x = self.norm1(x) - x = self.attn(x) + x = self.attn(x, hw_shape) x = x + identity identity = x @@ -117,38 +113,39 @@ class SwinBlockSequence(BaseModule): Args: embed_dims (int): Number of input channels. - input_resolution (Tuple[int, int]): The resolution of the input feature - map. depth (int): Number of successive swin transformer blocks. num_heads (int): Number of attention heads. - downsample (bool, optional): Downsample the output of blocks by patch - merging. Defaults to False. - downsample_cfg (dict, optional): The extra config of the patch merging - layer. Defaults to empty dict. - drop_paths (Sequence[float] | float, optional): The drop path rate in - each block. Defaults to 0. - block_cfgs (Sequence[dict] | dict, optional): The extra config of each - block. Defaults to empty dicts. - with_cp (bool, optional): Use checkpoint or not. Using checkpoint - will save some memory while slowing down the training speed. + window_size (int): The height and width of the window. Defaults to 7. + downsample (bool): Downsample the output of blocks by patch merging. + Defaults to False. + downsample_cfg (dict): The extra config of the patch merging layer. + Defaults to empty dict. + drop_paths (Sequence[float] | float): The drop path rate in each block. + Defaults to 0. + block_cfgs (Sequence[dict] | dict): The extra config of each block. + Defaults to empty dicts. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. 
If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. Defaults to False. - auto_pad (bool, optional): Auto pad the feature map to be divisible by - window_size, Defaults to False. init_cfg (dict, optional): The extra config for initialization. - Default: None. + Defaults to None. """ def __init__(self, embed_dims, - input_resolution, depth, num_heads, + window_size=7, downsample=False, downsample_cfg=dict(), drop_paths=0., block_cfgs=dict(), with_cp=False, - auto_pad=False, + pad_small_map=False, init_cfg=None): super().__init__(init_cfg) @@ -159,17 +156,16 @@ class SwinBlockSequence(BaseModule): block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)] self.embed_dims = embed_dims - self.input_resolution = input_resolution self.blocks = ModuleList() for i in range(depth): _block_cfg = { 'embed_dims': embed_dims, - 'input_resolution': input_resolution, 'num_heads': num_heads, + 'window_size': window_size, 'shift': False if i % 2 == 0 else True, 'drop_path': drop_paths[i], 'with_cp': with_cp, - 'auto_pad': auto_pad, + 'pad_small_map': pad_small_map, **block_cfgs[i] } block = SwinBlock(**_block_cfg) @@ -177,9 +173,8 @@ class SwinBlockSequence(BaseModule): if downsample: _downsample_cfg = { - 'input_resolution': input_resolution, 'in_channels': embed_dims, - 'expansion_ratio': 2, + 'out_channels': 2 * embed_dims, 'norm_cfg': dict(type='LN'), **downsample_cfg } @@ -187,20 +182,15 @@ class SwinBlockSequence(BaseModule): else: self.downsample = None - def forward(self, x): + def forward(self, x, in_shape): for block in self.blocks: - x = block(x) + x = block(x, in_shape) if self.downsample: - x = self.downsample(x) - return x - - @property - def out_resolution(self): - if self.downsample: - return self.downsample.output_resolution + x, out_shape = self.downsample(x, in_shape) else: - return self.input_resolution + out_shape = in_shape + return x, out_shape @property def out_channels(self): @@ -212,7 +202,8 @@ class SwinBlockSequence(BaseModule): @BACKBONES.register_module() class SwinTransformer(BaseBackbone): - """ Swin Transformer + """Swin Transformer. + A PyTorch implement of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_ @@ -221,34 +212,47 @@ class SwinTransformer(BaseBackbone): https://github.com/microsoft/Swin-Transformer Args: - arch (str | dict): Swin Transformer architecture - Defaults to 'T'. - img_size (int | tuple): The size of input image. - Defaults to 224. - in_channels (int): The num of input channels. - Defaults to 3. - drop_rate (float): Dropout rate after embedding. - Defaults to 0. - drop_path_rate (float): Stochastic depth rate. - Defaults to 0.1. + arch (str | dict): Swin Transformer architecture. If use string, choose + from 'tiny', 'small', 'base' and 'large'. If use dict, it should + have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **num_heads** (List[int]): The number of heads in attention + modules of each stage. + + Defaults to 'tiny'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 4. + in_channels (int): The num of input channels. Defaults to 3. + window_size (int): The height and width of the window. Defaults to 7. 
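To make the new stage interface concrete, here is a hedged sketch of `SwinBlockSequence` after the refactor above: the spatial shape now travels with the tensor instead of being baked in at build time, and `downsample=True` plugs in mmcv's `PatchMerging` with `out_channels=2 * embed_dims`, as configured in `_downsample_cfg`.

```python
import torch
from mmcls.models.backbones.swin_transformer import SwinBlockSequence

stage = SwinBlockSequence(
    embed_dims=96, depth=2, num_heads=3, window_size=7, downsample=True)
x = torch.randn(1, 56 * 56, 96)
x, out_shape = stage(x, in_shape=(56, 56))
# PatchMerging halves the resolution and doubles the channels.
assert out_shape == (28, 28) and x.shape == (1, 28 * 28, 192)
```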
+ drop_rate (float): Dropout rate after embedding. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. use_abs_pos_embed (bool): If True, add absolute position embedding to the patch embedding. Defaults to False. - with_cp (bool, optional): Use checkpoint or not. Using checkpoint - will save some memory while slowing down the training speed. - Defaults to False. + interpolate_mode (str): Select the interpolate mode for absolute + position embeding vector resize. Defaults to "bicubic". + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. Defaults to -1. norm_eval (bool): Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effect on Batch Norm and its variants only. Defaults to False. - auto_pad (bool): If True, auto pad feature map to fit window_size. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. Defaults to False. - norm_cfg (dict, optional): Config dict for normalization layer at end - of backone. Defaults to dict(type='LN') - stage_cfgs (Sequence | dict, optional): Extra config dict for each - stage. Defaults to empty dict. - patch_cfg (dict, optional): Extra config dict for patch embedding. - Defaults to empty dict. + norm_cfg (dict): Config dict for normalization layer for all output + features. Defaults to ``dict(type='LN')`` + stage_cfgs (Sequence[dict] | dict): Extra config dict for each + stage. Defaults to an empty dict. + patch_cfg (dict): Extra config dict for patch embedding. + Defaults to an empty dict. init_cfg (dict, optional): The Config for initialization. Defaults to None. 
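With the arguments documented above, the whole backbone now accepts inputs whose size differs from `img_size`. A minimal sketch (assuming the default `out_indices=(3, )`):

```python
import torch
from mmcls.models.backbones import SwinTransformer

model = SwinTransformer(arch='tiny', img_size=224)
for size in (224, 256, 384):
    out = model(torch.randn(1, 3, size, size))[-1]
    # Patch embedding (stride 4) plus three merges gives a total stride of 32.
    assert out.shape == (1, 768, size // 32, size // 32)
```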
@@ -258,8 +262,7 @@ class SwinTransformer(BaseBackbone): >>> extra_config = dict( >>> arch='tiny', >>> stage_cfgs=dict(downsample_cfg={'kernel_size': 3, - >>> 'expansion_ratio': 3}), - >>> auto_pad=True) + >>> 'expansion_ratio': 3})) >>> self = SwinTransformer(**extra_config) >>> inputs = torch.rand(1, 3, 224, 224) >>> output = self.forward(inputs) @@ -285,25 +288,29 @@ class SwinTransformer(BaseBackbone): 'num_heads': [6, 12, 24, 48]}), } # yapf: disable - _version = 2 + _version = 3 + num_extra_tokens = 0 def __init__(self, - arch='T', + arch='tiny', img_size=224, + patch_size=4, in_channels=3, + window_size=7, drop_rate=0., drop_path_rate=0.1, out_indices=(3, ), use_abs_pos_embed=False, - auto_pad=False, + interpolate_mode='bicubic', with_cp=False, frozen_stages=-1, norm_eval=False, + pad_small_map=False, norm_cfg=dict(type='LN'), stage_cfgs=dict(), patch_cfg=dict(), init_cfg=None): - super(SwinTransformer, self).__init__(init_cfg) + super(SwinTransformer, self).__init__(init_cfg=init_cfg) if isinstance(arch, str): arch = arch.lower() @@ -311,7 +318,7 @@ class SwinTransformer(BaseBackbone): f'Arch {arch} is not in default archs {set(self.arch_zoo)}' self.arch_settings = self.arch_zoo[arch] else: - essential_keys = {'embed_dims', 'depths', 'num_head'} + essential_keys = {'embed_dims', 'depths', 'num_heads'} assert isinstance(arch, dict) and set(arch) == essential_keys, \ f'Custom arch needs a dict with keys {essential_keys}' self.arch_settings = arch @@ -322,26 +329,28 @@ class SwinTransformer(BaseBackbone): self.num_layers = len(self.depths) self.out_indices = out_indices self.use_abs_pos_embed = use_abs_pos_embed - self.auto_pad = auto_pad + self.interpolate_mode = interpolate_mode self.frozen_stages = frozen_stages - self.num_extra_tokens = 0 - _patch_cfg = { - 'img_size': img_size, - 'in_channels': in_channels, - 'embed_dims': self.embed_dims, - 'conv_cfg': dict(type='Conv2d', kernel_size=4, stride=4), - 'norm_cfg': dict(type='LN'), - **patch_cfg - } + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + norm_cfg=dict(type='LN'), + ) + _patch_cfg.update(patch_cfg) self.patch_embed = PatchEmbed(**_patch_cfg) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution + self.patch_resolution = self.patch_embed.init_out_size if self.use_abs_pos_embed: + num_patches = self.patch_resolution[0] * self.patch_resolution[1] self.absolute_pos_embed = nn.Parameter( torch.zeros(1, num_patches, self.embed_dims)) + self._register_load_state_dict_pre_hook( + self._prepare_abs_pos_embed) self.drop_after_pos = nn.Dropout(p=drop_rate) self.norm_eval = norm_eval @@ -354,7 +363,6 @@ class SwinTransformer(BaseBackbone): self.stages = ModuleList() embed_dims = [self.embed_dims] - input_resolution = patches_resolution for i, (depth, num_heads) in enumerate(zip(self.depths, self.num_heads)): if isinstance(stage_cfgs, Sequence): @@ -366,11 +374,11 @@ class SwinTransformer(BaseBackbone): 'embed_dims': embed_dims[-1], 'depth': depth, 'num_heads': num_heads, + 'window_size': window_size, 'downsample': downsample, - 'input_resolution': input_resolution, 'drop_paths': dpr[:depth], 'with_cp': with_cp, - 'auto_pad': auto_pad, + 'pad_small_map': pad_small_map, **stage_cfg } @@ -379,7 +387,6 @@ class SwinTransformer(BaseBackbone): dpr = dpr[depth:] embed_dims.append(stage.out_channels) - input_resolution = 
stage.out_resolution for i in out_indices: if norm_cfg is not None: @@ -401,18 +408,20 @@ class SwinTransformer(BaseBackbone): trunc_normal_(self.absolute_pos_embed, std=0.02) def forward(self, x): - x = self.patch_embed(x) + x, hw_shape = self.patch_embed(x) if self.use_abs_pos_embed: - x = x + self.absolute_pos_embed + x = x + resize_pos_embed( + self.absolute_pos_embed, self.patch_resolution, hw_shape, + self.interpolate_mode, self.num_extra_tokens) x = self.drop_after_pos(x) outs = [] for i, stage in enumerate(self.stages): - x = stage(x) + x, hw_shape = stage(x, hw_shape) if i in self.out_indices: norm_layer = getattr(self, f'norm{i}') out = norm_layer(x) - out = out.view(-1, *stage.out_resolution, + out = out.view(-1, *hw_shape, stage.out_channels).permute(0, 3, 1, 2).contiguous() outs.append(out) @@ -433,6 +442,12 @@ class SwinTransformer(BaseBackbone): convert_key = k.replace('norm.', f'norm{final_stage_num}.') state_dict[convert_key] = state_dict[k] del state_dict[k] + if (version is None + or version < 3) and self.__class__ is SwinTransformer: + state_dict_keys = list(state_dict.keys()) + for k in state_dict_keys: + if 'attn_mask' in k: + del state_dict[k] super()._load_from_state_dict(state_dict, prefix, local_metadata, *args, **kwargs) @@ -461,3 +476,26 @@ class SwinTransformer(BaseBackbone): # trick: eval have effect on BatchNorm only if isinstance(m, _BatchNorm): m.eval() + + def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'absolute_pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.absolute_pos_embed.shape != ckpt_pos_embed_shape: + from mmcls.utils import get_root_logger + logger = get_root_logger() + logger.info( + 'Resize the absolute_pos_embed shape from ' + f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) diff --git a/mmcls/models/backbones/t2t_vit.py b/mmcls/models/backbones/t2t_vit.py index f28a7b1b..e3160ccd 100644 --- a/mmcls/models/backbones/t2t_vit.py +++ b/mmcls/models/backbones/t2t_vit.py @@ -11,7 +11,7 @@ from mmcv.cnn.utils.weight_init import trunc_normal_ from mmcv.runner.base_module import BaseModule, ModuleList from ..builder import BACKBONES -from ..utils import MultiheadAttention +from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple from .base_backbone import BaseBackbone @@ -173,26 +173,44 @@ class T2TModule(BaseModule): raise NotImplementedError("Performer hasn't been implemented.") # there are 3 soft split, stride are 4,2,2 separately - self.num_patches = (img_size // (4 * 2 * 2))**2 + out_side = img_size // (4 * 2 * 2) + self.init_out_size = [out_side, out_side] + self.num_patches = out_side**2 + + @staticmethod + def _get_unfold_size(unfold: nn.Unfold, input_size): + h, w = input_size + kernel_size = to_2tuple(unfold.kernel_size) + stride = to_2tuple(unfold.stride) + padding = to_2tuple(unfold.padding) + dilation = to_2tuple(unfold.dilation) + + h_out = (h + 2 * padding[0] - dilation[0] * + (kernel_size[0] - 1) - 1) // stride[0] + 1 + w_out = (w + 2 * padding[1] - dilation[1] * + (kernel_size[1] - 1) - 1) // stride[1] + 1 + return (h_out, w_out) def forward(self, x): # step0: soft split + hw_shape = 
self._get_unfold_size(self.soft_split0, x.shape[2:]) x = self.soft_split0(x).transpose(1, 2) for step in [1, 2]: # re-structurization/reconstruction attn = getattr(self, f'attention{step}') x = attn(x).transpose(1, 2) - B, C, new_HW = x.shape - x = x.reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW))) + B, C, _ = x.shape + x = x.reshape(B, C, hw_shape[0], hw_shape[1]) # soft split soft_split = getattr(self, f'soft_split{step}') + hw_shape = self._get_unfold_size(soft_split, hw_shape) x = soft_split(x).transpose(1, 2) # final tokens x = self.project(x) - return x + return x, hw_shape def get_sinusoid_encoding(n_position, embed_dims): @@ -231,43 +249,52 @@ class T2T_ViT(BaseBackbone): Transformers from Scratch on ImageNet `_ Args: - img_size (int): Input image size. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. in_channels (int): Number of input channels. embed_dims (int): Embedding dimension. - t2t_cfg (dict): Extra config of Tokens-to-Token module. - Defaults to an empty dict. - drop_rate (float): Dropout rate after position embedding. - Defaults to 0. num_layers (int): Num of transformer layers in encoder. Defaults to 14. out_indices (Sequence | int): Output from which stages. Defaults to -1, means the last stage. - layer_cfgs (Sequence | dict): Configs of each transformer layer in - encoder. Defaults to an empty dict. + drop_rate (float): Dropout rate after position embedding. + Defaults to 0. drop_path_rate (float): stochastic depth rate. Defaults to 0. norm_cfg (dict): Config dict for normalization layer. Defaults to ``dict(type='LN')``. final_norm (bool): Whether to add a additional layer to normalize final feature map. Defaults to True. - output_cls_token (bool): Whether output the cls_token. - Defaults to True. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + output_cls_token (bool): Whether output the cls_token. If set True, + ``with_cls_token`` must be True. Defaults to True. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + t2t_cfg (dict): Extra config of Tokens-to-Token module. + Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. init_cfg (dict, optional): The Config for initialization. Defaults to None. 
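The `_get_unfold_size` helper above is plain convolution output arithmetic applied to `nn.Unfold`. A worked example, assuming the usual T2T first soft split (kernel 7, stride 4, padding 2 — these defaults are not shown in this hunk):

```python
# out = (in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
h = 224
h_out = (h + 2 * 2 - 1 * (7 - 1) - 1) // 4 + 1
assert h_out == 56
# After the two remaining soft splits (stride 2 each): 56 -> 28 -> 14,
# matching init_out_size = img_size // (4 * 2 * 2).
```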
""" + num_extra_tokens = 1 # cls_token def __init__(self, img_size=224, in_channels=3, embed_dims=384, - t2t_cfg=dict(), - drop_rate=0., num_layers=14, out_indices=-1, - layer_cfgs=dict(), + drop_rate=0., drop_path_rate=0., norm_cfg=dict(type='LN'), final_norm=True, + with_cls_token=True, output_cls_token=True, + interpolate_mode='bicubic', + t2t_cfg=dict(), + layer_cfgs=dict(), init_cfg=None): super(T2T_ViT, self).__init__(init_cfg) @@ -277,30 +304,41 @@ class T2T_ViT(BaseBackbone): in_channels=in_channels, embed_dims=embed_dims, **t2t_cfg) - num_patches = self.tokens_to_token.num_patches + self.patch_resolution = self.tokens_to_token.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] - # Class token + # Set cls token + if output_cls_token: + assert with_cls_token is True, f'with_cls_token must be True if' \ + f'set output_cls_token to True, but got {with_cls_token}' + self.with_cls_token = with_cls_token self.output_cls_token = output_cls_token self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) - self.num_extra_tokens = 1 - # Position Embedding - sinusoid_table = get_sinusoid_encoding(num_patches + 1, embed_dims) + # Set position embedding + self.interpolate_mode = interpolate_mode + sinusoid_table = get_sinusoid_encoding( + num_patches + self.num_extra_tokens, embed_dims) self.register_buffer('pos_embed', sinusoid_table) + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + self.drop_after_pos = nn.Dropout(p=drop_rate) if isinstance(out_indices, int): out_indices = [out_indices] assert isinstance(out_indices, Sequence), \ - f'"out_indices" must by a sequence or int, ' \ + f'"out_indices" must be a sequence or int, ' \ f'get {type(out_indices)} instead.' for i, index in enumerate(out_indices): if index < 0: out_indices[i] = num_layers + index - assert out_indices[i] >= 0, f'Invalid out_indices {index}' + assert 0 <= out_indices[i] <= num_layers, \ + f'Invalid out_indices {index}' self.out_indices = out_indices + # stochastic depth decay rule dpr = [x for x in np.linspace(0, drop_path_rate, num_layers)] + self.encoder = ModuleList() for i in range(num_layers): if isinstance(layer_cfgs, Sequence): @@ -336,17 +374,49 @@ class T2T_ViT(BaseBackbone): trunc_normal_(self.cls_token, std=.02) + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmcls.utils import get_root_logger + logger = get_root_logger() + logger.info( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.tokens_to_token.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) + def forward(self, x): B = x.shape[0] - x = self.tokens_to_token(x) - num_patches = self.tokens_to_token.num_patches - patch_resolution = [int(np.sqrt(num_patches))] * 2 + x, patch_resolution = self.tokens_to_token(x) + # stole cls_tokens impl from Phil Wang, thanks cls_tokens = self.cls_token.expand(B, -1, -1) x = torch.cat((cls_tokens, x), dim=1) - x = x + self.pos_embed + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + 
num_extra_tokens=self.num_extra_tokens) x = self.drop_after_pos(x) + if not self.with_cls_token: + # Remove class token for transformer encoder input + x = x[:, 1:] + outs = [] for i, layer in enumerate(self.encoder): x = layer(x) @@ -356,9 +426,14 @@ class T2T_ViT(BaseBackbone): if i in self.out_indices: B, _, C = x.shape - patch_token = x[:, 1:].reshape(B, *patch_resolution, C) - patch_token = patch_token.permute(0, 3, 1, 2) - cls_token = x[:, 0] + if self.with_cls_token: + patch_token = x[:, 1:].reshape(B, *patch_resolution, C) + patch_token = patch_token.permute(0, 3, 1, 2) + cls_token = x[:, 0] + else: + patch_token = x.reshape(B, *patch_resolution, C) + patch_token = patch_token.permute(0, 3, 1, 2) + cls_token = None if self.output_cls_token: out = [patch_token, cls_token] else: diff --git a/mmcls/models/backbones/vision_transformer.py b/mmcls/models/backbones/vision_transformer.py index fc0f57b5..87a70640 100644 --- a/mmcls/models/backbones/vision_transformer.py +++ b/mmcls/models/backbones/vision_transformer.py @@ -4,15 +4,14 @@ from typing import Sequence import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from mmcv.cnn import build_norm_layer -from mmcv.cnn.bricks.transformer import FFN +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed from mmcv.cnn.utils.weight_init import trunc_normal_ from mmcv.runner.base_module import BaseModule, ModuleList from mmcls.utils import get_root_logger from ..builder import BACKBONES -from ..utils import MultiheadAttention, PatchEmbed, to_2tuple +from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple from .base_backbone import BaseBackbone @@ -108,21 +107,38 @@ class VisionTransformer(BaseBackbone): for Image Recognition at Scale `_ Args: - arch (str | dict): Vision Transformer architecture - Default: 'b' - img_size (int | tuple): Input image size - patch_size (int | tuple): The patch size + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small' + and 'deit-base'. If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + + Defaults to 'base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. out_indices (Sequence | int): Output from which stages. Defaults to -1, means the last stage. drop_rate (float): Probability of an element to be zeroed. Defaults to 0. drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. norm_cfg (dict): Config dict for normalization layer. Defaults to ``dict(type='LN')``. final_norm (bool): Whether to add a additional layer to normalize final feature map. Defaults to True. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. output_cls_token (bool): Whether output the cls_token. If set True, - `with_cls_token` must be True. Defaults to True. + ``with_cls_token`` must be True. Defaults to True. 
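Both the T2T forward above and the ViT forward below lean on the shared `resize_pos_embed` utility that this patch moves to `mmcls.models.utils`. A small sketch of its contract:

```python
import torch
from mmcls.models.utils import resize_pos_embed

# 1 cls token + 14x14 patch tokens, as in a 224-input ViT-base checkpoint.
pos_embed = torch.randn(1, 1 + 14 * 14, 768)
resized = resize_pos_embed(
    pos_embed, src_shape=(14, 14), dst_shape=(20, 20),
    mode='bicubic', num_extra_tokens=1)
assert resized.shape == (1, 1 + 20 * 20, 768)  # extra tokens pass through
```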
interpolate_mode (str): Select the interpolate mode for position embeding vector resize. Defaults to "bicubic". patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. @@ -138,7 +154,6 @@ class VisionTransformer(BaseBackbone): 'num_layers': 8, 'num_heads': 8, 'feedforward_channels': 768 * 3, - 'qkv_bias': False }), **dict.fromkeys( ['b', 'base'], { @@ -180,14 +195,17 @@ class VisionTransformer(BaseBackbone): num_extra_tokens = 1 # cls_token def __init__(self, - arch='b', + arch='base', img_size=224, patch_size=16, + in_channels=3, out_indices=-1, drop_rate=0., drop_path_rate=0., + qkv_bias=True, norm_cfg=dict(type='LN', eps=1e-6), final_norm=True, + with_cls_token=True, output_cls_token=True, interpolate_mode='bicubic', patch_cfg=dict(), @@ -214,16 +232,23 @@ class VisionTransformer(BaseBackbone): # Set patch embedding _patch_cfg = dict( - img_size=img_size, + in_channels=in_channels, + input_size=img_size, embed_dims=self.embed_dims, - conv_cfg=dict( - type='Conv2d', kernel_size=patch_size, stride=patch_size), + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, ) _patch_cfg.update(patch_cfg) self.patch_embed = PatchEmbed(**_patch_cfg) - num_patches = self.patch_embed.num_patches + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] # Set cls token + if output_cls_token: + assert with_cls_token is True, f'with_cls_token must be True if' \ + f'set output_cls_token to True, but got {with_cls_token}' + self.with_cls_token = with_cls_token self.output_cls_token = output_cls_token self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) @@ -232,6 +257,8 @@ class VisionTransformer(BaseBackbone): self.pos_embed = nn.Parameter( torch.zeros(1, num_patches + self.num_extra_tokens, self.embed_dims)) + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + self.drop_after_pos = nn.Dropout(p=drop_rate) if isinstance(out_indices, int): @@ -242,11 +269,12 @@ class VisionTransformer(BaseBackbone): for i, index in enumerate(out_indices): if index < 0: out_indices[i] = self.num_layers + index - assert out_indices[i] >= 0, f'Invalid out_indices {index}' + assert 0 <= out_indices[i] <= self.num_layers, \ + f'Invalid out_indices {index}' self.out_indices = out_indices # stochastic depth decay rule - dpr = np.linspace(0, drop_path_rate, self.arch_settings['num_layers']) + dpr = np.linspace(0, drop_path_rate, self.num_layers) self.layers = ModuleList() if isinstance(layer_cfgs, dict): @@ -259,7 +287,7 @@ class VisionTransformer(BaseBackbone): arch_settings['feedforward_channels'], drop_rate=drop_rate, drop_path_rate=dpr[i], - qkv_bias=self.arch_settings.get('qkv_bias', True), + qkv_bias=qkv_bias, norm_cfg=norm_cfg) _layer_cfg.update(layer_cfgs[i]) self.layers.append(TransformerEncoderLayer(**_layer_cfg)) @@ -270,8 +298,6 @@ class VisionTransformer(BaseBackbone): norm_cfg, self.embed_dims, postfix=1) self.add_module(self.norm1_name, norm1) - self._register_load_state_dict_pre_hook(self._prepare_checkpoint_hook) - @property def norm1(self): return getattr(self, self.norm1_name) @@ -283,7 +309,7 @@ class VisionTransformer(BaseBackbone): and self.init_cfg['type'] == 'Pretrained'): trunc_normal_(self.pos_embed, std=0.02) - def _prepare_checkpoint_hook(self, state_dict, prefix, *args, **kwargs): + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): name = prefix + 'pos_embed' if name not in state_dict.keys(): return @@ -299,61 +325,38 @@ class VisionTransformer(BaseBackbone): 
ckpt_pos_embed_shape = to_2tuple( int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) - pos_embed_shape = self.patch_embed.patches_resolution + pos_embed_shape = self.patch_embed.init_out_size - state_dict[name] = self.resize_pos_embed(state_dict[name], - ckpt_pos_embed_shape, - pos_embed_shape, - self.interpolate_mode, - self.num_extra_tokens) + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) @staticmethod - def resize_pos_embed(pos_embed, - src_shape, - dst_shape, - mode='bicubic', - num_extra_tokens=1): - """Resize pos_embed weights. - - Args: - pos_embed (torch.Tensor): Position embedding weights with shape - [1, L, C]. - src_shape (tuple): The resolution of downsampled origin training - image. - dst_shape (tuple): The resolution of downsampled new training - image. - mode (str): Algorithm used for upsampling: - ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | - ``'trilinear'``. Default: ``'bicubic'`` - Return: - torch.Tensor: The resized pos_embed of shape [1, L_new, C] - """ - assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]' - _, L, C = pos_embed.shape - src_h, src_w = src_shape - assert L == src_h * src_w + num_extra_tokens - extra_tokens = pos_embed[:, :num_extra_tokens] - - src_weight = pos_embed[:, num_extra_tokens:] - src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2) - - dst_weight = F.interpolate( - src_weight, size=dst_shape, align_corners=False, mode=mode) - dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2) - - return torch.cat((extra_tokens, dst_weight), dim=1) + def resize_pos_embed(*args, **kwargs): + """Interface for backward-compatibility.""" + return resize_pos_embed(*args, **kwargs) def forward(self, x): B = x.shape[0] - x = self.patch_embed(x) - patch_resolution = self.patch_embed.patches_resolution + x, patch_resolution = self.patch_embed(x) # stole cls_tokens impl from Phil Wang, thanks cls_tokens = self.cls_token.expand(B, -1, -1) x = torch.cat((cls_tokens, x), dim=1) - x = x + self.pos_embed + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) x = self.drop_after_pos(x) + if not self.with_cls_token: + # Remove class token for transformer encoder input + x = x[:, 1:] + outs = [] for i, layer in enumerate(self.layers): x = layer(x) @@ -363,9 +366,14 @@ class VisionTransformer(BaseBackbone): if i in self.out_indices: B, _, C = x.shape - patch_token = x[:, 1:].reshape(B, *patch_resolution, C) - patch_token = patch_token.permute(0, 3, 1, 2) - cls_token = x[:, 0] + if self.with_cls_token: + patch_token = x[:, 1:].reshape(B, *patch_resolution, C) + patch_token = patch_token.permute(0, 3, 1, 2) + cls_token = x[:, 0] + else: + patch_token = x.reshape(B, *patch_resolution, C) + patch_token = patch_token.permute(0, 3, 1, 2) + cls_token = None if self.output_cls_token: out = [patch_token, cls_token] else: diff --git a/mmcls/models/utils/__init__.py b/mmcls/models/utils/__init__.py index aaf30c3e..395915ec 100644 --- a/mmcls/models/utils/__init__.py +++ b/mmcls/models/utils/__init__.py @@ -2,7 +2,7 @@ from .attention import MultiheadAttention, ShiftWindowMSA from .augment.augments import Augments from .channel_shuffle import channel_shuffle -from .embed import HybridEmbed, PatchEmbed, PatchMerging +from .embed import HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed from .helpers import 
is_tracing, to_2tuple, to_3tuple, to_4tuple, to_ntuple from .inverted_residual import InvertedResidual from .make_divisible import make_divisible @@ -13,5 +13,5 @@ __all__ = [ 'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer', 'to_ntuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'PatchEmbed', 'PatchMerging', 'HybridEmbed', 'Augments', 'ShiftWindowMSA', 'is_tracing', - 'MultiheadAttention', 'ConditionalPositionEncoding' + 'MultiheadAttention', 'ConditionalPositionEncoding', 'resize_pos_embed' ] diff --git a/mmcls/models/utils/attention.py b/mmcls/models/utils/attention.py index 65954112..155127f7 100644 --- a/mmcls/models/utils/attention.py +++ b/mmcls/models/utils/attention.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +import warnings + import torch import torch.nn as nn import torch.nn.functional as F @@ -126,14 +128,12 @@ class ShiftWindowMSA(BaseModule): Args: embed_dims (int): Number of input channels. - input_resolution (Tuple[int, int]): The resolution of the input feature - map. num_heads (int): Number of attention heads. window_size (int): The height and width of the window. shift_size (int, optional): The shift step of each window towards right-bottom. If zero, act as regular window-msa. Defaults to 0. qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. - Default: True + Defaults to True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. Defaults to None. attn_drop (float, optional): Dropout ratio of attention weight. @@ -141,15 +141,17 @@ class ShiftWindowMSA(BaseModule): proj_drop (float, optional): Dropout ratio of output. Defaults to 0. dropout_layer (dict, optional): The dropout_layer used before output. Defaults to dict(type='DropPath', drop_prob=0.). - auto_pad (bool, optional): Auto pad the feature map to be divisible by - window_size, Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. init_cfg (dict, optional): The extra config for initialization. - Default: None. + Defaults to None. """ def __init__(self, embed_dims, - input_resolution, num_heads, window_size, shift_size=0, @@ -158,53 +160,134 @@ class ShiftWindowMSA(BaseModule): attn_drop=0, proj_drop=0, dropout_layer=dict(type='DropPath', drop_prob=0.), - auto_pad=False, + pad_small_map=False, + input_resolution=None, + auto_pad=None, init_cfg=None): super().__init__(init_cfg) - self.embed_dims = embed_dims - self.input_resolution = input_resolution + if input_resolution is not None or auto_pad is not None: + warnings.warn( + 'The ShiftWindowMSA in new version has supported auto padding ' + 'and dynamic input shape in all condition. 
And the argument ' + '`auto_pad` and `input_resolution` have been deprecated.', + DeprecationWarning) + self.shift_size = shift_size self.window_size = window_size - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, don't partition - self.shift_size = 0 - self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size - self.w_msa = WindowMSA(embed_dims, to_2tuple(self.window_size), - num_heads, qkv_bias, qk_scale, attn_drop, - proj_drop) + self.w_msa = WindowMSA( + embed_dims=embed_dims, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) self.drop = build_dropout(dropout_layer) + self.pad_small_map = pad_small_map - H, W = self.input_resolution - # Handle auto padding - self.auto_pad = auto_pad - if self.auto_pad: - self.pad_r = (self.window_size - - W % self.window_size) % self.window_size - self.pad_b = (self.window_size - - H % self.window_size) % self.window_size - self.H_pad = H + self.pad_b - self.W_pad = W + self.pad_r - else: - H_pad, W_pad = self.input_resolution - assert H_pad % self.window_size + W_pad % self.window_size == 0,\ - f'input_resolution({self.input_resolution}) is not divisible '\ - f'by window_size({self.window_size}). Please check feature '\ - f'map shape or set `auto_pad=True`.' - self.H_pad, self.W_pad = H_pad, W_pad - self.pad_r, self.pad_b = 0, 0 + def forward(self, query, hw_shape): + B, L, C = query.shape + H, W = hw_shape + assert L == H * W, f"The query length {L} doesn't match the input "\ + f'shape ({H}, {W}).' + query = query.view(B, H, W, C) + window_size = self.window_size + shift_size = self.shift_size + + if min(H, W) == window_size: + # If not pad small feature map, avoid shifting when the window size + # is equal to the size of feature map. It's to align with the + # behavior of the original implementation. + shift_size = shift_size if self.pad_small_map else 0 + elif min(H, W) < window_size: + # In the original implementation, the window size will be shrunk + # to the size of feature map. The behavior is different with + # swin-transformer for downstream tasks. To support dynamic input + # shape, we don't allow this feature. + assert self.pad_small_map, \ + f'The input shape ({H}, {W}) is smaller than the window ' \ + f'size ({window_size}). Please set `pad_small_map=True`, or ' \ + 'decrease the `window_size`.' 
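The branch above only shrinks the window or requires padding for small maps; the rest of the new forward handles any `hw_shape` at call time. A hedged usage sketch:

```python
import torch
from mmcls.models.utils import ShiftWindowMSA

attn = ShiftWindowMSA(
    embed_dims=96, num_heads=3, window_size=7, shift_size=3)
x = torch.randn(1, 20 * 20, 96)
# 20 is not divisible by 7: the map is padded to 21x21 for the window
# attention, then cropped back, so the output length matches the input.
out = attn(x, hw_shape=(20, 20))
assert out.shape == x.shape
```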
+ + pad_r = (window_size - W % window_size) % window_size + pad_b = (window_size - H % window_size) % window_size + query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b)) + + H_pad, W_pad = query.shape[1], query.shape[2] + + # cyclic shift + if shift_size > 0: + query = torch.roll( + query, shifts=(-shift_size, -shift_size), dims=(1, 2)) + + attn_mask = self.get_attn_mask((H_pad, W_pad), + window_size=window_size, + shift_size=shift_size, + device=query.device) + + # nW*B, window_size, window_size, C + query_windows = self.window_partition(query, window_size) + # nW*B, window_size*window_size, C + query_windows = query_windows.view(-1, window_size**2, C) + + # W-MSA/SW-MSA (nW*B, window_size*window_size, C) + attn_windows = self.w_msa(query_windows, mask=attn_mask) + + # merge windows + attn_windows = attn_windows.view(-1, window_size, window_size, C) + + # B H' W' C + shifted_x = self.window_reverse(attn_windows, H_pad, W_pad, + window_size) + # reverse cyclic shift if self.shift_size > 0: - # calculate attention mask for SW-MSA - img_mask = torch.zeros((1, self.H_pad, self.W_pad, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, - -self.shift_size), slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, - -self.shift_size), slice(-self.shift_size, None)) + x = torch.roll( + shifted_x, shifts=(shift_size, shift_size), dims=(1, 2)) + else: + x = shifted_x + + if H != H_pad or W != W_pad: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + x = self.drop(x) + + return x + + @staticmethod + def window_reverse(windows, H, W, window_size): + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + @staticmethod + def window_partition(x, window_size): + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, + window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = windows.view(-1, window_size, window_size, C) + return windows + + @staticmethod + def get_attn_mask(hw_shape, window_size, shift_size, device=None): + if shift_size > 0: + img_mask = torch.zeros(1, *hw_shape, 1, device=device) + h_slices = (slice(0, -window_size), slice(-window_size, + -shift_size), + slice(-shift_size, None)) + w_slices = (slice(0, -window_size), slice(-window_size, + -shift_size), + slice(-shift_size, None)) cnt = 0 for h in h_slices: for w in w_slices: @@ -212,83 +295,15 @@ class ShiftWindowMSA(BaseModule): cnt += 1 # nW, window_size, window_size, 1 - mask_windows = self.window_partition(img_mask) - mask_windows = mask_windows.view( - -1, self.window_size * self.window_size) + mask_windows = ShiftWindowMSA.window_partition( + img_mask, window_size) + mask_windows = mask_windows.view(-1, window_size * window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, - float(-100.0)).masked_fill( - attn_mask == 0, float(0.0)) + attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0) + attn_mask = attn_mask.masked_fill(attn_mask == 0, 0.0) else: attn_mask = None - - self.register_buffer('attn_mask', attn_mask) - - def forward(self, query): - H, W = self.input_resolution - B, L, C = query.shape - assert L == H * W, 'input feature has wrong size' - query = query.view(B, H, W, C) - - if self.pad_r or self.pad_b: - query = F.pad(query, (0, 
0, 0, self.pad_r, 0, self.pad_b)) - - # cyclic shift - if self.shift_size > 0: - shifted_query = torch.roll( - query, - shifts=(-self.shift_size, -self.shift_size), - dims=(1, 2)) - else: - shifted_query = query - - # nW*B, window_size, window_size, C - query_windows = self.window_partition(shifted_query) - # nW*B, window_size*window_size, C - query_windows = query_windows.view(-1, self.window_size**2, C) - - # W-MSA/SW-MSA (nW*B, window_size*window_size, C) - attn_windows = self.w_msa(query_windows, mask=self.attn_mask) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, - self.window_size, C) - - # B H' W' C - shifted_x = self.window_reverse(attn_windows, self.H_pad, self.W_pad) - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll( - shifted_x, - shifts=(self.shift_size, self.shift_size), - dims=(1, 2)) - else: - x = shifted_x - - if self.pad_r or self.pad_b: - x = x[:, :H, :W, :].contiguous() - - x = x.view(B, H * W, C) - - x = self.drop(x) - return x - - def window_reverse(self, windows, H, W): - window_size = self.window_size - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, - window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - def window_partition(self, x): - B, H, W, C = x.shape - window_size = self.window_size - x = x.view(B, H // window_size, window_size, W // window_size, - window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() - windows = windows.view(-1, window_size, window_size, C) - return windows + return attn_mask class MultiheadAttention(BaseModule): diff --git a/mmcls/models/utils/embed.py b/mmcls/models/utils/embed.py index b5f7be27..ffe2c855 100644 --- a/mmcls/models/utils/embed.py +++ b/mmcls/models/utils/embed.py @@ -1,12 +1,59 @@ # Copyright (c) OpenMMLab. All rights reserved. +import warnings + import torch import torch.nn as nn +import torch.nn.functional as F from mmcv.cnn import build_conv_layer, build_norm_layer from mmcv.runner.base_module import BaseModule from .helpers import to_2tuple +def resize_pos_embed(pos_embed, + src_shape, + dst_shape, + mode='bicubic', + num_extra_tokens=1): + """Resize pos_embed weights. + + Args: + pos_embed (torch.Tensor): Position embedding weights with shape + [1, L, C]. + src_shape (tuple): The resolution of downsampled origin training + image, in format (H, W). + dst_shape (tuple): The resolution of downsampled new training + image, in format (H, W). + mode (str): Algorithm used for upsampling. Choose one from 'nearest', + 'linear', 'bilinear', 'bicubic' and 'trilinear'. + Defaults to 'bicubic'. + num_extra_tokens (int): The number of extra tokens, such as cls_token. + Defaults to 1. + + Returns: + torch.Tensor: The resized pos_embed of shape [1, L_new, C] + """ + if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]: + return pos_embed + assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]' + _, L, C = pos_embed.shape + src_h, src_w = src_shape + assert L == src_h * src_w + num_extra_tokens, \ + f"The length of `pos_embed` ({L}) doesn't match the expected " \ + f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the' \ + '`img_size` argument.' 
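As a sanity check on the `ShiftWindowMSA` refactor above, the now-static `window_partition` and `window_reverse` helpers are exact inverses of each other:

```python
import torch
from mmcls.models.utils import ShiftWindowMSA

x = torch.randn(2, 14, 14, 32)                    # (B, H, W, C)
windows = ShiftWindowMSA.window_partition(x, 7)   # (B * 2 * 2, 7, 7, 32)
restored = ShiftWindowMSA.window_reverse(windows, 14, 14, 7)
assert torch.equal(x, restored)
```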
+ extra_tokens = pos_embed[:, :num_extra_tokens] + + src_weight = pos_embed[:, num_extra_tokens:] + src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2) + + dst_weight = F.interpolate( + src_weight, size=dst_shape, align_corners=False, mode=mode) + dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2) + + return torch.cat((extra_tokens, dst_weight), dim=1) + + class PatchEmbed(BaseModule): """Image to Patch Embedding. @@ -32,6 +79,9 @@ class PatchEmbed(BaseModule): conv_cfg=None, init_cfg=None): super(PatchEmbed, self).__init__(init_cfg) + warnings.warn('The `PatchEmbed` in mmcls will be deprecated. ' + 'Please use `mmcv.cnn.bricks.transformer.PatchEmbed`. ' + "It's more general and supports dynamic input shape") if isinstance(img_size, int): img_size = to_2tuple(img_size) @@ -203,6 +253,10 @@ class PatchMerging(BaseModule): norm_cfg=dict(type='LN'), init_cfg=None): super().__init__(init_cfg) + warnings.warn('The `PatchMerging` in mmcls will be deprecated. ' + 'Please use `mmcv.cnn.bricks.transformer.PatchMerging`. ' + "It's more general and supports dynamic input shape") + H, W = input_resolution self.input_resolution = input_resolution self.in_channels = in_channels diff --git a/tests/test_downstream/test_mmdet_inference.py b/tests/test_downstream/test_mmdet_inference.py index 6da3ba16..096c5db7 100644 --- a/tests/test_downstream/test_mmdet_inference.py +++ b/tests/test_downstream/test_mmdet_inference.py @@ -52,8 +52,7 @@ backbone_configs = dict( arch='small', drop_path_rate=0.2, img_size=800, - out_indices=(2, 3), - auto_pad=True), + out_indices=(2, 3)), out_channels=[384, 768]), timm_efficientnet=dict( backbone=dict( diff --git a/tests/test_models/test_backbones/__init__.py b/tests/test_models/test_backbones/__init__.py new file mode 100644 index 00000000..ef101fec --- /dev/null +++ b/tests/test_models/test_backbones/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_backbones/test_deit.py b/tests/test_models/test_backbones/test_deit.py index af9efe9c..5f11a3ae 100644 --- a/tests/test_models/test_backbones/test_deit.py +++ b/tests/test_models/test_backbones/test_deit.py @@ -1,43 +1,131 @@ # Copyright (c) OpenMMLab. All rights reserved. 
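The deprecation warnings above point users at mmcv's implementations, which the refactored backbones in this patch already use. A sketch of the replacement API, assuming an mmcv version matching this patch: mmcv's `PatchEmbed` returns the actual output grid alongside the tokens, which is what enables dynamic input shape.

```python
import torch
from mmcv.cnn.bricks.transformer import PatchEmbed

patch_embed = PatchEmbed(
    in_channels=3, embed_dims=768, conv_type='Conv2d',
    kernel_size=16, stride=16, input_size=224)
# input_size only fixes init_out_size; the forward adapts to 256x256.
x, out_size = patch_embed(torch.randn(1, 3, 256, 256))
assert out_size == (16, 16) and x.shape == (1, 16 * 16, 768)
```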
+import math +import os +import tempfile +from copy import deepcopy +from unittest import TestCase import torch -from torch.nn.modules.batchnorm import _BatchNorm +from mmcv.runner import load_checkpoint, save_checkpoint from mmcls.models.backbones import DistilledVisionTransformer +from .utils import timm_resize_pos_embed -def check_norm_state(modules, train_state): - """Check if norm layer is in correct train state.""" - for mod in modules: - if isinstance(mod, _BatchNorm): - if mod.training != train_state: - return False - return True +class TestDeiT(TestCase): + def setUp(self): + self.cfg = dict( + arch='deit-base', img_size=224, patch_size=16, drop_rate=0.1) -def test_deit_backbone(): - cfg_ori = dict(arch='deit-b', img_size=224, patch_size=16) + def test_init_weights(self): + # test weight init cfg + cfg = deepcopy(self.cfg) + cfg['init_cfg'] = [ + dict( + type='Kaiming', + layer='Conv2d', + mode='fan_in', + nonlinearity='linear') + ] + model = DistilledVisionTransformer(**cfg) + ori_weight = model.patch_embed.projection.weight.clone().detach() + # The pos_embed is all zero before initialize + self.assertTrue(torch.allclose(model.dist_token, torch.tensor(0.))) - # Test structure - model = DistilledVisionTransformer(**cfg_ori) - model.init_weights() - model.train() + model.init_weights() + initialized_weight = model.patch_embed.projection.weight + self.assertFalse(torch.allclose(ori_weight, initialized_weight)) + self.assertFalse(torch.allclose(model.dist_token, torch.tensor(0.))) - assert check_norm_state(model.modules(), True) - assert model.dist_token.shape == (1, 1, 768) - assert model.pos_embed.shape == (1, model.patch_embed.num_patches + 2, 768) + # test load checkpoint + pretrain_pos_embed = model.pos_embed.clone().detach() + tmpdir = tempfile.gettempdir() + checkpoint = os.path.join(tmpdir, 'test.pth') + save_checkpoint(model, checkpoint) + cfg = deepcopy(self.cfg) + model = DistilledVisionTransformer(**cfg) + load_checkpoint(model, checkpoint, strict=True) + self.assertTrue(torch.allclose(model.pos_embed, pretrain_pos_embed)) - # Test forward - imgs = torch.rand(1, 3, 224, 224) - outs = model(imgs) - patch_token, cls_token, dist_token = outs[0] - assert patch_token.shape == (1, 768, 14, 14) - assert cls_token.shape == (1, 768) - assert dist_token.shape == (1, 768) + # test load checkpoint with different img_size + cfg = deepcopy(self.cfg) + cfg['img_size'] = 384 + model = DistilledVisionTransformer(**cfg) + load_checkpoint(model, checkpoint, strict=True) + resized_pos_embed = timm_resize_pos_embed( + pretrain_pos_embed, model.pos_embed, num_tokens=2) + self.assertTrue(torch.allclose(model.pos_embed, resized_pos_embed)) - # Test multiple out_indices - model = DistilledVisionTransformer( - **cfg_ori, out_indices=(0, 1, 2, 3), output_cls_token=False) - outs = model(imgs) - for out in outs: - assert out.shape == (1, 768, 14, 14) + os.remove(checkpoint) + + def test_forward(self): + imgs = torch.randn(3, 3, 224, 224) + + # test with_cls_token=False + cfg = deepcopy(self.cfg) + cfg['with_cls_token'] = False + cfg['output_cls_token'] = True + with self.assertRaisesRegex(AssertionError, 'but got False'): + DistilledVisionTransformer(**cfg) + + cfg = deepcopy(self.cfg) + cfg['with_cls_token'] = False + cfg['output_cls_token'] = False + model = DistilledVisionTransformer(**cfg) + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 1) + patch_token = outs[-1] + self.assertEqual(patch_token.shape, (3, 768, 14, 14)) + + # test with output_cls_token + cfg 
= deepcopy(self.cfg) + model = DistilledVisionTransformer(**cfg) + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 1) + patch_token, cls_token, dist_token = outs[-1] + self.assertEqual(patch_token.shape, (3, 768, 14, 14)) + self.assertEqual(cls_token.shape, (3, 768)) + self.assertEqual(dist_token.shape, (3, 768)) + + # test without output_cls_token + cfg = deepcopy(self.cfg) + cfg['output_cls_token'] = False + model = DistilledVisionTransformer(**cfg) + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 1) + patch_token = outs[-1] + self.assertEqual(patch_token.shape, (3, 768, 14, 14)) + + # Test forward with multi out indices + cfg = deepcopy(self.cfg) + cfg['out_indices'] = [-3, -2, -1] + model = DistilledVisionTransformer(**cfg) + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 3) + for out in outs: + patch_token, cls_token, dist_token = out + self.assertEqual(patch_token.shape, (3, 768, 14, 14)) + self.assertEqual(cls_token.shape, (3, 768)) + self.assertEqual(dist_token.shape, (3, 768)) + + # Test forward with dynamic input size + imgs1 = torch.randn(3, 3, 224, 224) + imgs2 = torch.randn(3, 3, 256, 256) + imgs3 = torch.randn(3, 3, 256, 309) + cfg = deepcopy(self.cfg) + model = DistilledVisionTransformer(**cfg) + for imgs in [imgs1, imgs2, imgs3]: + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 1) + patch_token, cls_token, dist_token = outs[-1] + expect_feat_shape = (math.ceil(imgs.shape[2] / 16), + math.ceil(imgs.shape[3] / 16)) + self.assertEqual(patch_token.shape, (3, 768, *expect_feat_shape)) + self.assertEqual(cls_token.shape, (3, 768)) + self.assertEqual(dist_token.shape, (3, 768)) diff --git a/tests/test_models/test_backbones/test_mlp_mixer.py b/tests/test_models/test_backbones/test_mlp_mixer.py index cff14dbf..d065a680 100644 --- a/tests/test_models/test_backbones/test_mlp_mixer.py +++ b/tests/test_models/test_backbones/test_mlp_mixer.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from copy import deepcopy +from unittest import TestCase -import pytest import torch from torch.nn.modules import GroupNorm from torch.nn.modules.batchnorm import _BatchNorm @@ -25,51 +25,95 @@ def check_norm_state(modules, train_state): return True -def test_mlp_mixer_backbone(): - cfg_ori = dict( - arch='b', - img_size=224, - patch_size=16, - drop_rate=0.1, - init_cfg=[ +class TestMLPMixer(TestCase): + + def setUp(self): + self.cfg = dict( + arch='b', + img_size=224, + patch_size=16, + drop_rate=0.1, + init_cfg=[ + dict( + type='Kaiming', + layer='Conv2d', + mode='fan_in', + nonlinearity='linear') + ]) + + def test_arch(self): + # Test invalid default arch + with self.assertRaisesRegex(AssertionError, 'not in default archs'): + cfg = deepcopy(self.cfg) + cfg['arch'] = 'unknown' + MlpMixer(**cfg) + + # Test invalid custom arch + with self.assertRaisesRegex(AssertionError, 'Custom arch needs'): + cfg = deepcopy(self.cfg) + cfg['arch'] = { + 'embed_dims': 24, + 'num_layers': 16, + 'tokens_mlp_dims': 4096 + } + MlpMixer(**cfg) + + # Test custom arch + cfg = deepcopy(self.cfg) + cfg['arch'] = { + 'embed_dims': 128, + 'num_layers': 6, + 'tokens_mlp_dims': 256, + 'channels_mlp_dims': 1024 + } + model = MlpMixer(**cfg) + self.assertEqual(model.embed_dims, 128) + self.assertEqual(model.num_layers, 6) + for layer in model.layers: + self.assertEqual(layer.token_mix.feedforward_channels, 256) + self.assertEqual(layer.channel_mix.feedforward_channels, 1024) + + def test_init_weights(self): + # test weight init cfg + cfg = deepcopy(self.cfg) + cfg['init_cfg'] = [ dict( type='Kaiming', layer='Conv2d', mode='fan_in', nonlinearity='linear') - ]) + ] + model = MlpMixer(**cfg) + ori_weight = model.patch_embed.projection.weight.clone().detach() + model.init_weights() + initialized_weight = model.patch_embed.projection.weight + self.assertFalse(torch.allclose(ori_weight, initialized_weight)) - with pytest.raises(AssertionError): - # test invalid arch - cfg = deepcopy(cfg_ori) - cfg['arch'] = 'unknown' - MlpMixer(**cfg) + def test_forward(self): + imgs = torch.randn(3, 3, 224, 224) - with pytest.raises(AssertionError): - # test arch without essential keys - cfg = deepcopy(cfg_ori) - cfg['arch'] = { - 'num_layers': 24, - 'tokens_mlp_dims': 384, - 'channels_mlp_dims': 3072, - } - MlpMixer(**cfg) + # test forward with single out indices + cfg = deepcopy(self.cfg) + model = MlpMixer(**cfg) + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 1) + feat = outs[-1] + self.assertEqual(feat.shape, (3, 768, 196)) - # Test MlpMixer base model with input size of 224 - # and patch size of 16 - model = MlpMixer(**cfg_ori) - model.init_weights() - model.train() + # test forward with multi out indices + cfg = deepcopy(self.cfg) + cfg['out_indices'] = [-3, -2, -1] + model = MlpMixer(**cfg) + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 3) + for feat in outs: + self.assertEqual(feat.shape, (3, 768, 196)) - assert check_norm_state(model.modules(), True) - - imgs = torch.randn(3, 3, 224, 224) - feat = model(imgs)[-1] - assert feat.shape == (3, 768, 196) - - # Test MlpMixer with multi out indices - cfg = deepcopy(cfg_ori) - cfg['out_indices'] = [-3, -2, -1] - model = MlpMixer(**cfg) - for out in model(imgs): - assert out.shape == (3, 768, 196) + # test with invalid input shape + imgs2 = torch.randn(3, 3, 256, 256) + cfg = deepcopy(self.cfg) + model = MlpMixer(**cfg) + with self.assertRaisesRegex(AssertionError, 'dynamic input shape.'): + model(imgs2) diff 
--git a/tests/test_models/test_backbones/test_swin_transformer.py b/tests/test_models/test_backbones/test_swin_transformer.py index 9a7d2782..b90ac0ed 100644 --- a/tests/test_models/test_backbones/test_swin_transformer.py +++ b/tests/test_models/test_backbones/test_swin_transformer.py @@ -1,16 +1,18 @@ # Copyright (c) OpenMMLab. All rights reserved. +import math import os import tempfile -from math import ceil +from copy import deepcopy +from itertools import chain +from unittest import TestCase -import numpy as np -import pytest import torch from mmcv.runner import load_checkpoint, save_checkpoint from mmcv.utils.parrots_wrapper import _BatchNorm from mmcls.models.backbones import SwinTransformer from mmcls.models.backbones.swin_transformer import SwinBlock +from .utils import timm_resize_pos_embed def check_norm_state(modules, train_state): @@ -22,215 +24,232 @@ def check_norm_state(modules, train_state): return True -def test_assertion(): - """Test Swin Transformer backbone.""" - with pytest.raises(AssertionError): - # Swin Transformer arch string should be in - SwinTransformer(arch='unknown') +class TestSwinTransformer(TestCase): - with pytest.raises(AssertionError): - # Swin Transformer arch dict should include 'embed_dims', - # 'depths' and 'num_head' keys. - SwinTransformer(arch=dict(embed_dims=96, depths=[2, 2, 18, 2])) + def setUp(self): + self.cfg = dict( + arch='b', img_size=224, patch_size=4, drop_path_rate=0.1) + def test_arch(self): + # Test invalid default arch + with self.assertRaisesRegex(AssertionError, 'not in default archs'): + cfg = deepcopy(self.cfg) + cfg['arch'] = 'unknown' + SwinTransformer(**cfg) -def test_forward(): - # Test tiny arch forward - model = SwinTransformer(arch='Tiny') - model.init_weights() - model.train() + # Test invalid custom arch + with self.assertRaisesRegex(AssertionError, 'Custom arch needs'): + cfg = deepcopy(self.cfg) + cfg['arch'] = { + 'embed_dims': 96, + 'num_heads': [3, 6, 12, 16], + } + SwinTransformer(**cfg) - imgs = torch.randn(1, 3, 224, 224) - output = model(imgs) - assert len(output) == 1 - assert output[0].shape == (1, 768, 7, 7) + # Test custom arch + cfg = deepcopy(self.cfg) + depths = [2, 2, 4, 2] + num_heads = [6, 12, 6, 12] + cfg['arch'] = { + 'embed_dims': 256, + 'depths': depths, + 'num_heads': num_heads + } + model = SwinTransformer(**cfg) + for i, stage in enumerate(model.stages): + self.assertEqual(stage.embed_dims, 256 * (2**i)) + self.assertEqual(len(stage.blocks), depths[i]) + self.assertEqual(stage.blocks[0].attn.w_msa.num_heads, + num_heads[i]) - # Test small arch forward - model = SwinTransformer(arch='small') - model.init_weights() - model.train() + def test_init_weights(self): + # test weight init cfg + cfg = deepcopy(self.cfg) + cfg['use_abs_pos_embed'] = True + cfg['init_cfg'] = [ + dict( + type='Kaiming', + layer='Conv2d', + mode='fan_in', + nonlinearity='linear') + ] + model = SwinTransformer(**cfg) + ori_weight = model.patch_embed.projection.weight.clone().detach() + # The pos_embed is all zero before initialize + self.assertTrue( + torch.allclose(model.absolute_pos_embed, torch.tensor(0.))) - imgs = torch.randn(1, 3, 224, 224) - output = model(imgs) - assert len(output) == 1 - assert output[0].shape == (1, 768, 7, 7) + model.init_weights() + initialized_weight = model.patch_embed.projection.weight + self.assertFalse(torch.allclose(ori_weight, initialized_weight)) + self.assertFalse( + torch.allclose(model.absolute_pos_embed, torch.tensor(0.))) - # Test base arch forward - model = SwinTransformer(arch='B') 
-    model.init_weights()
-    model.train()
+        pretrain_pos_embed = model.absolute_pos_embed.clone().detach()

-    imgs = torch.randn(1, 3, 224, 224)
-    output = model(imgs)
-    assert len(output) == 1
-    assert output[0].shape == (1, 1024, 7, 7)
+        tmpdir = tempfile.gettempdir()
+        # Save v3 checkpoint
+        checkpoint_v3 = os.path.join(tmpdir, 'v3.pth')
+        save_checkpoint(model, checkpoint_v3)
+        # Save v1 checkpoint
+        setattr(model, 'norm', model.norm3)
+        setattr(model.stages[0].blocks[1].attn, 'attn_mask',
+                torch.zeros(64, 49, 49))
+        model._version = 1
+        del model.norm3
+        checkpoint_v1 = os.path.join(tmpdir, 'v1.pth')
+        save_checkpoint(model, checkpoint_v1)

-    # Test large arch forward
-    model = SwinTransformer(arch='l')
-    model.init_weights()
-    model.train()
+        # test load v1 checkpoint
+        cfg = deepcopy(self.cfg)
+        cfg['use_abs_pos_embed'] = True
+        model = SwinTransformer(**cfg)
+        load_checkpoint(model, checkpoint_v1, strict=True)

-    imgs = torch.randn(1, 3, 224, 224)
-    output = model(imgs)
-    assert len(output) == 1
-    assert output[0].shape == (1, 1536, 7, 7)
+        # test load v3 checkpoint
+        cfg = deepcopy(self.cfg)
+        cfg['use_abs_pos_embed'] = True
+        model = SwinTransformer(**cfg)
+        load_checkpoint(model, checkpoint_v3, strict=True)

-    # Test base arch with window_size=12, image_size=384
-    model = SwinTransformer(
-        arch='base',
-        img_size=384,
-        stage_cfgs=dict(block_cfgs=dict(window_size=12)))
-    model.init_weights()
-    model.train()
+        # test load v3 checkpoint with different img_size
+        cfg = deepcopy(self.cfg)
+        cfg['img_size'] = 384
+        cfg['use_abs_pos_embed'] = True
+        model = SwinTransformer(**cfg)
+        load_checkpoint(model, checkpoint_v3, strict=True)
+        resized_pos_embed = timm_resize_pos_embed(
+            pretrain_pos_embed, model.absolute_pos_embed, num_tokens=0)
+        self.assertTrue(
+            torch.allclose(model.absolute_pos_embed, resized_pos_embed))

-    imgs = torch.randn(1, 3, 384, 384)
-    output = model(imgs)
-    assert len(output) == 1
-    assert output[0].shape == (1, 1024, 12, 12)
+        os.remove(checkpoint_v1)
+        os.remove(checkpoint_v3)

-    # Test multiple output indices
-    imgs = torch.randn(1, 3, 224, 224)
-    model = SwinTransformer(arch='base', out_indices=(0, 1, 2, 3))
-    outs = model(imgs)
-    assert outs[0].shape == (1, 256, 28, 28)
-    assert outs[1].shape == (1, 512, 14, 14)
-    assert outs[2].shape == (1, 1024, 7, 7)
-    assert outs[3].shape == (1, 1024, 7, 7)
+    def test_forward(self):
+        imgs = torch.randn(3, 3, 224, 224)

-    # Test base arch with with checkpoint forward
-    model = SwinTransformer(arch='B', with_cp=True)
-    for m in model.modules():
-        if isinstance(m, SwinBlock):
-            assert m.with_cp
-    model.init_weights()
-    model.train()
+        cfg = deepcopy(self.cfg)
+        model = SwinTransformer(**cfg)
+        outs = model(imgs)
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 1)
+        feat = outs[-1]
+        self.assertEqual(feat.shape, (3, 1024, 7, 7))

-    imgs = torch.randn(1, 3, 224, 224)
-    output = model(imgs)
-    assert len(output) == 1
-    assert output[0].shape == (1, 1024, 7, 7)
+        # test with window_size=12
+        cfg = deepcopy(self.cfg)
+        cfg['window_size'] = 12
+        model = SwinTransformer(**cfg)
+        outs = model(torch.randn(3, 3, 384, 384))
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 1)
+        feat = outs[-1]
+        self.assertEqual(feat.shape, (3, 1024, 12, 12))
+        with self.assertRaisesRegex(AssertionError, r'the window size \(12\)'):
+            model(torch.randn(3, 3, 224, 224))

+        # test with pad_small_map=True
+        cfg = deepcopy(self.cfg)
+        cfg['window_size'] = 12
+        cfg['pad_small_map'] = True
+        model = SwinTransformer(**cfg)
+        outs = model(torch.randn(3, 3, 224, 224))
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 1)
+        feat = outs[-1]
+        self.assertEqual(feat.shape, (3, 1024, 7, 7))

-def test_structure():
-    # Test small with use_abs_pos_embed = True
-    model = SwinTransformer(arch='small', use_abs_pos_embed=True)
-    assert model.absolute_pos_embed.shape == (1, 3136, 96)
+        # test multiple output indices
+        cfg = deepcopy(self.cfg)
+        cfg['out_indices'] = (0, 1, 2, 3)
+        model = SwinTransformer(**cfg)
+        outs = model(imgs)
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 4)
+        for stride, out in zip([2, 4, 8, 8], outs):
+            self.assertEqual(out.shape,
+                             (3, 128 * stride, 56 // stride, 56 // stride))

-    # Test small with use_abs_pos_embed = False
-    model = SwinTransformer(arch='small', use_abs_pos_embed=False)
-    assert not hasattr(model, 'absolute_pos_embed')
+        # test with checkpoint forward
+        cfg = deepcopy(self.cfg)
+        cfg['with_cp'] = True
+        model = SwinTransformer(**cfg)
+        for m in model.modules():
+            if isinstance(m, SwinBlock):
+                self.assertTrue(m.with_cp)
+        model.init_weights()
+        model.train()

-    # Test small with auto_pad = True
-    model = SwinTransformer(
-        arch='small',
-        auto_pad=True,
-        stage_cfgs=dict(
-            block_cfgs={'window_size': 7},
-            downsample_cfg={
-                'kernel_size': (3, 2),
-            }))
+        outs = model(imgs)
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 1)
+        feat = outs[-1]
+        self.assertEqual(feat.shape, (3, 1024, 7, 7))

-    # stage 1
-    input_h = int(224 / 4 / 3)
-    expect_h = ceil(input_h / 7) * 7
-    input_w = int(224 / 4 / 2)
-    expect_w = ceil(input_w / 7) * 7
-    assert model.stages[1].blocks[0].attn.pad_b == expect_h - input_h
-    assert model.stages[1].blocks[0].attn.pad_r == expect_w - input_w
+        # test with dynamic input shape
+        imgs1 = torch.randn(3, 3, 224, 224)
+        imgs2 = torch.randn(3, 3, 256, 256)
+        imgs3 = torch.randn(3, 3, 256, 309)
+        cfg = deepcopy(self.cfg)
+        model = SwinTransformer(**cfg)
+        for imgs in [imgs1, imgs2, imgs3]:
+            outs = model(imgs)
+            self.assertIsInstance(outs, tuple)
+            self.assertEqual(len(outs), 1)
+            feat = outs[-1]
+            expect_feat_shape = (math.ceil(imgs.shape[2] / 32),
+                                 math.ceil(imgs.shape[3] / 32))
+            self.assertEqual(feat.shape, (3, 1024, *expect_feat_shape))

-    # stage 2
-    input_h = int(224 / 4 / 3 / 3)
-    # input_h is smaller than window_size, shrink the window_size to input_h.
- expect_h = input_h - input_w = int(224 / 4 / 2 / 2) - expect_w = ceil(input_w / input_h) * input_h - assert model.stages[2].blocks[0].attn.pad_b == expect_h - input_h - assert model.stages[2].blocks[0].attn.pad_r == expect_w - input_w + def test_structure(self): + # test drop_path_rate decay + cfg = deepcopy(self.cfg) + cfg['drop_path_rate'] = 0.2 + model = SwinTransformer(**cfg) + depths = model.arch_settings['depths'] + blocks = chain(*[stage.blocks for stage in model.stages]) + for i, block in enumerate(blocks): + expect_prob = 0.2 / (sum(depths) - 1) * i + self.assertAlmostEqual(block.ffn.dropout_layer.drop_prob, + expect_prob) + self.assertAlmostEqual(block.attn.drop.drop_prob, expect_prob) - # stage 3 - input_h = int(224 / 4 / 3 / 3 / 3) - expect_h = input_h - input_w = int(224 / 4 / 2 / 2 / 2) - expect_w = ceil(input_w / input_h) * input_h - assert model.stages[3].blocks[0].attn.pad_b == expect_h - input_h - assert model.stages[3].blocks[0].attn.pad_r == expect_w - input_w + # test Swin-Transformer with norm_eval=True + cfg = deepcopy(self.cfg) + cfg['norm_eval'] = True + cfg['norm_cfg'] = dict(type='BN') + cfg['stage_cfgs'] = dict(block_cfgs=dict(norm_cfg=dict(type='BN'))) + model = SwinTransformer(**cfg) + model.init_weights() + model.train() + self.assertTrue(check_norm_state(model.modules(), False)) - # Test small with auto_pad = False - with pytest.raises(AssertionError): - model = SwinTransformer( - arch='small', - auto_pad=False, - stage_cfgs=dict( - block_cfgs={'window_size': 7}, - downsample_cfg={ - 'kernel_size': (3, 2), - })) + # test Swin-Transformer with first stage frozen. + cfg = deepcopy(self.cfg) + frozen_stages = 0 + cfg['frozen_stages'] = frozen_stages + cfg['out_indices'] = (0, 1, 2, 3) + model = SwinTransformer(**cfg) + model.init_weights() + model.train() - # Test drop_path_rate decay - model = SwinTransformer( - arch='small', - drop_path_rate=0.2, - ) - depths = model.arch_settings['depths'] - pos = 0 - for i, depth in enumerate(depths): - for j in range(depth): - block = model.stages[i].blocks[j] - expect_prob = 0.2 / (sum(depths) - 1) * pos - assert np.isclose(block.ffn.dropout_layer.drop_prob, expect_prob) - assert np.isclose(block.attn.drop.drop_prob, expect_prob) - pos += 1 + # the patch_embed and first stage should not require grad. + self.assertFalse(model.patch_embed.training) + for param in model.patch_embed.parameters(): + self.assertFalse(param.requires_grad) + for i in range(frozen_stages + 1): + stage = model.stages[i] + for param in stage.parameters(): + self.assertFalse(param.requires_grad) + for param in model.norm0.parameters(): + self.assertFalse(param.requires_grad) - # Test Swin-Transformer with norm_eval=True - model = SwinTransformer( - arch='small', - norm_eval=True, - norm_cfg=dict(type='BN'), - stage_cfgs=dict(block_cfgs=dict(norm_cfg=dict(type='BN'))), - ) - model.init_weights() - model.train() - assert check_norm_state(model.modules(), False) - - # Test Swin-Transformer with first stage frozen. - frozen_stages = 0 - model = SwinTransformer( - arch='small', frozen_stages=frozen_stages, out_indices=(0, 1, 2, 3)) - model.init_weights() - model.train() - - assert model.patch_embed.training is False - for param in model.patch_embed.parameters(): - assert param.requires_grad is False - for i in range(frozen_stages + 1): - stage = model.stages[i] - for param in stage.parameters(): - assert param.requires_grad is False - for param in model.norm0.parameters(): - assert param.requires_grad is False - - # the second stage should require grad. 
- stage = model.stages[1] - for param in stage.parameters(): - assert param.requires_grad is True - for param in model.norm1.parameters(): - assert param.requires_grad is True - - -def test_load_checkpoint(): - model = SwinTransformer(arch='tiny') - ckpt_path = os.path.join(tempfile.gettempdir(), 'ckpt.pth') - - assert model._version == 2 - - # test load v2 checkpoint - save_checkpoint(model, ckpt_path) - load_checkpoint(model, ckpt_path, strict=True) - - # test load v1 checkpoint - setattr(model, 'norm', model.norm3) - model._version = 1 - del model.norm3 - save_checkpoint(model, ckpt_path) - model = SwinTransformer(arch='tiny') - load_checkpoint(model, ckpt_path, strict=True) + # the second stage should require grad. + for i in range(frozen_stages + 1, 4): + stage = model.stages[i] + for param in stage.parameters(): + self.assertTrue(param.requires_grad) + norm = getattr(model, f'norm{i}') + for param in norm.parameters(): + self.assertTrue(param.requires_grad) diff --git a/tests/test_models/test_backbones/test_t2t_vit.py b/tests/test_models/test_backbones/test_t2t_vit.py index e15f92f9..cc7e839c 100644 --- a/tests/test_models/test_backbones/test_t2t_vit.py +++ b/tests/test_models/test_backbones/test_t2t_vit.py @@ -1,84 +1,157 @@ # Copyright (c) OpenMMLab. All rights reserved. +import math +import os +import tempfile from copy import deepcopy +from unittest import TestCase -import pytest import torch -from torch.nn.modules import GroupNorm -from torch.nn.modules.batchnorm import _BatchNorm +from mmcv.runner import load_checkpoint, save_checkpoint from mmcls.models.backbones import T2T_ViT +from .utils import timm_resize_pos_embed -def is_norm(modules): - """Check if is one of the norms.""" - if isinstance(modules, (GroupNorm, _BatchNorm)): - return True - return False +class TestT2TViT(TestCase): + def setUp(self): + self.cfg = dict( + img_size=224, + in_channels=3, + embed_dims=384, + t2t_cfg=dict( + token_dims=64, + use_performer=False, + ), + num_layers=14, + drop_path_rate=0.1) -def check_norm_state(modules, train_state): - """Check if norm layer is in correct train state.""" - for mod in modules: - if isinstance(mod, _BatchNorm): - if mod.training != train_state: - return False - return True - - -def test_vit_backbone(): - - cfg_ori = dict( - img_size=224, - in_channels=3, - embed_dims=384, - t2t_cfg=dict( - token_dims=64, - use_performer=False, - ), - num_layers=14, - layer_cfgs=dict( - num_heads=6, - feedforward_channels=3 * 384, # mlp_ratio = 3 - ), - drop_path_rate=0.1, - init_cfg=[ - dict(type='TruncNormal', layer='Linear', std=.02), - dict(type='Constant', layer='LayerNorm', val=1., bias=0.), - ]) - - with pytest.raises(NotImplementedError): - # test if use performer - cfg = deepcopy(cfg_ori) + def test_structure(self): + # The performer hasn't been implemented + cfg = deepcopy(self.cfg) cfg['t2t_cfg']['use_performer'] = True - T2T_ViT(**cfg) + with self.assertRaises(NotImplementedError): + T2T_ViT(**cfg) - # Test T2T-ViT model with input size of 224 - model = T2T_ViT(**cfg_ori) - model.init_weights() - model.train() + # Test out_indices + cfg = deepcopy(self.cfg) + cfg['out_indices'] = {1: 1} + with self.assertRaisesRegex(AssertionError, "get "): + T2T_ViT(**cfg) + cfg['out_indices'] = [0, 15] + with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 15'): + T2T_ViT(**cfg) - assert check_norm_state(model.modules(), True) + # Test model structure + cfg = deepcopy(self.cfg) + model = T2T_ViT(**cfg) + self.assertEqual(len(model.encoder), 14) + dpr_inc = 0.1 / (14 - 1) + 
dpr = 0 + for layer in model.encoder: + self.assertEqual(layer.attn.embed_dims, 384) + # The default mlp_ratio is 3 + self.assertEqual(layer.ffn.feedforward_channels, 384 * 3) + self.assertAlmostEqual(layer.attn.out_drop.drop_prob, dpr) + self.assertAlmostEqual(layer.ffn.dropout_layer.drop_prob, dpr) + dpr += dpr_inc - imgs = torch.randn(3, 3, 224, 224) - patch_token, cls_token = model(imgs)[-1] - assert cls_token.shape == (3, 384) - assert patch_token.shape == (3, 384, 14, 14) + def test_init_weights(self): + # test weight init cfg + cfg = deepcopy(self.cfg) + cfg['init_cfg'] = [dict(type='TruncNormal', layer='Linear', std=.02)] + model = T2T_ViT(**cfg) + ori_weight = model.tokens_to_token.project.weight.clone().detach() - # Test custom arch T2T-ViT without output cls token - cfg = deepcopy(cfg_ori) - cfg['embed_dims'] = 256 - cfg['num_layers'] = 16 - cfg['layer_cfgs'] = dict(num_heads=8, feedforward_channels=1024) - cfg['output_cls_token'] = False + model.init_weights() + initialized_weight = model.tokens_to_token.project.weight + self.assertFalse(torch.allclose(ori_weight, initialized_weight)) - model = T2T_ViT(**cfg) - patch_token = model(imgs)[-1] - assert patch_token.shape == (3, 256, 14, 14) + # test load checkpoint + pretrain_pos_embed = model.pos_embed.clone().detach() + tmpdir = tempfile.gettempdir() + checkpoint = os.path.join(tmpdir, 'test.pth') + save_checkpoint(model, checkpoint) + cfg = deepcopy(self.cfg) + model = T2T_ViT(**cfg) + load_checkpoint(model, checkpoint, strict=True) + self.assertTrue(torch.allclose(model.pos_embed, pretrain_pos_embed)) - # Test T2T_ViT with multi out indices - cfg = deepcopy(cfg_ori) - cfg['out_indices'] = [-3, -2, -1] - model = T2T_ViT(**cfg) - for out in model(imgs): - assert out[0].shape == (3, 384, 14, 14) - assert out[1].shape == (3, 384) + # test load checkpoint with different img_size + cfg = deepcopy(self.cfg) + cfg['img_size'] = 384 + model = T2T_ViT(**cfg) + load_checkpoint(model, checkpoint, strict=True) + resized_pos_embed = timm_resize_pos_embed(pretrain_pos_embed, + model.pos_embed) + self.assertTrue(torch.allclose(model.pos_embed, resized_pos_embed)) + + os.remove(checkpoint) + + def test_forward(self): + imgs = torch.randn(3, 3, 224, 224) + + # test with_cls_token=False + cfg = deepcopy(self.cfg) + cfg['with_cls_token'] = False + cfg['output_cls_token'] = True + with self.assertRaisesRegex(AssertionError, 'but got False'): + T2T_ViT(**cfg) + + cfg = deepcopy(self.cfg) + cfg['with_cls_token'] = False + cfg['output_cls_token'] = False + model = T2T_ViT(**cfg) + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 1) + patch_token = outs[-1] + self.assertEqual(patch_token.shape, (3, 384, 14, 14)) + + # test with output_cls_token + cfg = deepcopy(self.cfg) + model = T2T_ViT(**cfg) + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 1) + patch_token, cls_token = outs[-1] + self.assertEqual(patch_token.shape, (3, 384, 14, 14)) + self.assertEqual(cls_token.shape, (3, 384)) + + # test without output_cls_token + cfg = deepcopy(self.cfg) + cfg['output_cls_token'] = False + model = T2T_ViT(**cfg) + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 1) + patch_token = outs[-1] + self.assertEqual(patch_token.shape, (3, 384, 14, 14)) + + # Test forward with multi out indices + cfg = deepcopy(self.cfg) + cfg['out_indices'] = [-3, -2, -1] + model = T2T_ViT(**cfg) + outs = model(imgs) + self.assertIsInstance(outs, tuple) + 
self.assertEqual(len(outs), 3) + for out in outs: + patch_token, cls_token = out + self.assertEqual(patch_token.shape, (3, 384, 14, 14)) + self.assertEqual(cls_token.shape, (3, 384)) + + # Test forward with dynamic input size + imgs1 = torch.randn(3, 3, 224, 224) + imgs2 = torch.randn(3, 3, 256, 256) + imgs3 = torch.randn(3, 3, 256, 309) + cfg = deepcopy(self.cfg) + model = T2T_ViT(**cfg) + for imgs in [imgs1, imgs2, imgs3]: + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 1) + patch_token, cls_token = outs[-1] + expect_feat_shape = (math.ceil(imgs.shape[2] / 16), + math.ceil(imgs.shape[3] / 16)) + self.assertEqual(patch_token.shape, (3, 384, *expect_feat_shape)) + self.assertEqual(cls_token.shape, (3, 384)) diff --git a/tests/test_models/test_backbones/test_vision_transformer.py b/tests/test_models/test_backbones/test_vision_transformer.py index efa7375c..26cc7370 100644 --- a/tests/test_models/test_backbones/test_vision_transformer.py +++ b/tests/test_models/test_backbones/test_vision_transformer.py @@ -3,160 +3,181 @@ import math import os import tempfile from copy import deepcopy +from unittest import TestCase -import pytest import torch -import torch.nn.functional as F -from torch.nn.modules import GroupNorm -from torch.nn.modules.batchnorm import _BatchNorm +from mmcv.runner import load_checkpoint, save_checkpoint from mmcls.models.backbones import VisionTransformer +from .utils import timm_resize_pos_embed -def is_norm(modules): - """Check if is one of the norms.""" - if isinstance(modules, (GroupNorm, _BatchNorm)): - return True - return False +class TestVisionTransformer(TestCase): + def setUp(self): + self.cfg = dict( + arch='b', img_size=224, patch_size=16, drop_path_rate=0.1) -def check_norm_state(modules, train_state): - """Check if norm layer is in correct train state.""" - for mod in modules: - if isinstance(mod, _BatchNorm): - if mod.training != train_state: - return False - return True + def test_structure(self): + # Test invalid default arch + with self.assertRaisesRegex(AssertionError, 'not in default archs'): + cfg = deepcopy(self.cfg) + cfg['arch'] = 'unknown' + VisionTransformer(**cfg) + # Test invalid custom arch + with self.assertRaisesRegex(AssertionError, 'Custom arch needs'): + cfg = deepcopy(self.cfg) + cfg['arch'] = { + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': 4096 + } + VisionTransformer(**cfg) -def test_vit_backbone(): + # Test custom arch + cfg = deepcopy(self.cfg) + cfg['arch'] = { + 'embed_dims': 128, + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': 1024 + } + model = VisionTransformer(**cfg) + self.assertEqual(model.embed_dims, 128) + self.assertEqual(model.num_layers, 24) + for layer in model.layers: + self.assertEqual(layer.attn.num_heads, 16) + self.assertEqual(layer.ffn.feedforward_channels, 1024) - cfg_ori = dict( - arch='b', - img_size=224, - patch_size=16, - drop_rate=0.1, - init_cfg=[ + # Test out_indices + cfg = deepcopy(self.cfg) + cfg['out_indices'] = {1: 1} + with self.assertRaisesRegex(AssertionError, "get "): + VisionTransformer(**cfg) + cfg['out_indices'] = [0, 13] + with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'): + VisionTransformer(**cfg) + + # Test model structure + cfg = deepcopy(self.cfg) + model = VisionTransformer(**cfg) + self.assertEqual(len(model.layers), 12) + dpr_inc = 0.1 / (12 - 1) + dpr = 0 + for layer in model.layers: + self.assertEqual(layer.attn.embed_dims, 768) + self.assertEqual(layer.attn.num_heads, 12) + 
self.assertEqual(layer.ffn.feedforward_channels, 3072) + self.assertAlmostEqual(layer.attn.out_drop.drop_prob, dpr) + self.assertAlmostEqual(layer.ffn.dropout_layer.drop_prob, dpr) + dpr += dpr_inc + + def test_init_weights(self): + # test weight init cfg + cfg = deepcopy(self.cfg) + cfg['init_cfg'] = [ dict( type='Kaiming', layer='Conv2d', mode='fan_in', nonlinearity='linear') - ]) + ] + model = VisionTransformer(**cfg) + ori_weight = model.patch_embed.projection.weight.clone().detach() + # The pos_embed is all zero before initialize + self.assertTrue(torch.allclose(model.pos_embed, torch.tensor(0.))) - with pytest.raises(AssertionError): - # test invalid arch - cfg = deepcopy(cfg_ori) - cfg['arch'] = 'unknown' - VisionTransformer(**cfg) + model.init_weights() + initialized_weight = model.patch_embed.projection.weight + self.assertFalse(torch.allclose(ori_weight, initialized_weight)) + self.assertFalse(torch.allclose(model.pos_embed, torch.tensor(0.))) - with pytest.raises(AssertionError): - # test arch without essential keys - cfg = deepcopy(cfg_ori) - cfg['arch'] = { - 'num_layers': 24, - 'num_heads': 16, - 'feedforward_channels': 4096 - } - VisionTransformer(**cfg) + # test load checkpoint + pretrain_pos_embed = model.pos_embed.clone().detach() + tmpdir = tempfile.gettempdir() + checkpoint = os.path.join(tmpdir, 'test.pth') + save_checkpoint(model, checkpoint) + cfg = deepcopy(self.cfg) + model = VisionTransformer(**cfg) + load_checkpoint(model, checkpoint, strict=True) + self.assertTrue(torch.allclose(model.pos_embed, pretrain_pos_embed)) - # Test ViT base model with input size of 224 - # and patch size of 16 - model = VisionTransformer(**cfg_ori) - model.init_weights() - model.train() + # test load checkpoint with different img_size + cfg = deepcopy(self.cfg) + cfg['img_size'] = 384 + model = VisionTransformer(**cfg) + load_checkpoint(model, checkpoint, strict=True) + resized_pos_embed = timm_resize_pos_embed(pretrain_pos_embed, + model.pos_embed) + self.assertTrue(torch.allclose(model.pos_embed, resized_pos_embed)) - assert check_norm_state(model.modules(), True) + os.remove(checkpoint) - imgs = torch.randn(3, 3, 224, 224) - patch_token, cls_token = model(imgs)[-1] - assert cls_token.shape == (3, 768) - assert patch_token.shape == (3, 768, 14, 14) + def test_forward(self): + imgs = torch.randn(3, 3, 224, 224) - # Test custom arch ViT without output cls token - cfg = deepcopy(cfg_ori) - cfg['arch'] = { - 'embed_dims': 128, - 'num_layers': 24, - 'num_heads': 16, - 'feedforward_channels': 1024 - } - cfg['output_cls_token'] = False - model = VisionTransformer(**cfg) - patch_token = model(imgs)[-1] - assert patch_token.shape == (3, 128, 14, 14) + # test with_cls_token=False + cfg = deepcopy(self.cfg) + cfg['with_cls_token'] = False + cfg['output_cls_token'] = True + with self.assertRaisesRegex(AssertionError, 'but got False'): + VisionTransformer(**cfg) - # Test ViT with multi out indices - cfg = deepcopy(cfg_ori) - cfg['out_indices'] = [-3, -2, -1] - model = VisionTransformer(**cfg) - for out in model(imgs): - assert out[0].shape == (3, 768, 14, 14) - assert out[1].shape == (3, 768) + cfg = deepcopy(self.cfg) + cfg['with_cls_token'] = False + cfg['output_cls_token'] = False + model = VisionTransformer(**cfg) + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 1) + patch_token = outs[-1] + self.assertEqual(patch_token.shape, (3, 768, 14, 14)) + # test with output_cls_token + cfg = deepcopy(self.cfg) + model = VisionTransformer(**cfg) + outs = 
model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 1) + patch_token, cls_token = outs[-1] + self.assertEqual(patch_token.shape, (3, 768, 14, 14)) + self.assertEqual(cls_token.shape, (3, 768)) -def timm_resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): - # Timm version pos embed resize function. - # Refers to https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # noqa:E501 - ntok_new = posemb_new.shape[1] - if num_tokens: - posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, - num_tokens:] - ntok_new -= num_tokens - else: - posemb_tok, posemb_grid = posemb[:, :0], posemb[0] - gs_old = int(math.sqrt(len(posemb_grid))) - if not len(gs_new): # backwards compatibility - gs_new = [int(math.sqrt(ntok_new))] * 2 - assert len(gs_new) >= 2 - posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, - -1).permute(0, 3, 1, 2) - posemb_grid = F.interpolate( - posemb_grid, size=gs_new, mode='bicubic', align_corners=False) - posemb_grid = posemb_grid.permute(0, 2, 3, - 1).reshape(1, gs_new[0] * gs_new[1], -1) - posemb = torch.cat([posemb_tok, posemb_grid], dim=1) - return posemb + # test without output_cls_token + cfg = deepcopy(self.cfg) + cfg['output_cls_token'] = False + model = VisionTransformer(**cfg) + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 1) + patch_token = outs[-1] + self.assertEqual(patch_token.shape, (3, 768, 14, 14)) + # Test forward with multi out indices + cfg = deepcopy(self.cfg) + cfg['out_indices'] = [-3, -2, -1] + model = VisionTransformer(**cfg) + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 3) + for out in outs: + patch_token, cls_token = out + self.assertEqual(patch_token.shape, (3, 768, 14, 14)) + self.assertEqual(cls_token.shape, (3, 768)) -def test_vit_weight_init(): - # test weight init cfg - pretrain_cfg = dict( - arch='b', - img_size=224, - patch_size=16, - init_cfg=[dict(type='Constant', val=1., layer='Conv2d')]) - pretrain_model = VisionTransformer(**pretrain_cfg) - pretrain_model.init_weights() - assert torch.allclose(pretrain_model.patch_embed.projection.weight, - torch.tensor(1.)) - assert pretrain_model.pos_embed.abs().sum() > 0 - - pos_embed_weight = pretrain_model.pos_embed.detach() - tmpdir = tempfile.gettempdir() - checkpoint = os.path.join(tmpdir, 'test.pth') - torch.save(pretrain_model.state_dict(), checkpoint) - - # test load checkpoint - finetune_cfg = dict( - arch='b', - img_size=224, - patch_size=16, - init_cfg=dict(type='Pretrained', checkpoint=checkpoint)) - finetune_model = VisionTransformer(**finetune_cfg) - finetune_model.init_weights() - assert torch.allclose(finetune_model.pos_embed, pos_embed_weight) - - # test load checkpoint with different img_size - finetune_cfg = dict( - arch='b', - img_size=384, - patch_size=16, - init_cfg=dict(type='Pretrained', checkpoint=checkpoint)) - finetune_model = VisionTransformer(**finetune_cfg) - finetune_model.init_weights() - resized_pos_embed = timm_resize_pos_embed(pos_embed_weight, - finetune_model.pos_embed) - assert torch.allclose(finetune_model.pos_embed, resized_pos_embed) - - os.remove(checkpoint) + # Test forward with dynamic input size + imgs1 = torch.randn(3, 3, 224, 224) + imgs2 = torch.randn(3, 3, 256, 256) + imgs3 = torch.randn(3, 3, 256, 309) + cfg = deepcopy(self.cfg) + model = VisionTransformer(**cfg) + for imgs in [imgs1, imgs2, imgs3]: + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 1) + 
patch_token, cls_token = outs[-1] + expect_feat_shape = (math.ceil(imgs.shape[2] / 16), + math.ceil(imgs.shape[3] / 16)) + self.assertEqual(patch_token.shape, (3, 768, *expect_feat_shape)) + self.assertEqual(cls_token.shape, (3, 768)) diff --git a/tests/test_models/test_backbones/utils.py b/tests/test_models/test_backbones/utils.py new file mode 100644 index 00000000..aba9cafb --- /dev/null +++ b/tests/test_models/test_backbones/utils.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn.functional as F + + +def timm_resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): + """Timm version pos embed resize function. + + copied from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + """ # noqa:E501 + ntok_new = posemb_new.shape[1] + if num_tokens: + posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, + num_tokens:] + ntok_new -= num_tokens + else: + posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + assert len(gs_new) >= 2 + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, + -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate( + posemb_grid, size=gs_new, mode='bicubic', align_corners=False) + posemb_grid = posemb_grid.permute(0, 2, 3, + 1).reshape(1, gs_new[0] * gs_new[1], -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + return posemb diff --git a/tests/test_models/test_utils/test_attention.py b/tests/test_models/test_utils/test_attention.py index 271df90f..9626f66f 100644 --- a/tests/test_models/test_utils/test_attention.py +++ b/tests/test_models/test_utils/test_attention.py @@ -1,5 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-import numpy as np
+from unittest import TestCase
+from unittest.mock import ANY, MagicMock
+
+import pytest
 import torch

 from mmcls.models.utils.attention import ShiftWindowMSA, WindowMSA
@@ -22,157 +25,177 @@ def get_relative_position_index(window_size):
     return relative_position_index


-def test_window_msa():
-    batch_size = 1
-    num_windows = (4, 4)
-    embed_dims = 96
-    window_size = (7, 7)
-    num_heads = 4
-    attn = WindowMSA(
-        embed_dims=embed_dims, window_size=window_size, num_heads=num_heads)
-    inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
-                         window_size[0] * window_size[1], embed_dims))
+class TestWindowMSA(TestCase):

-    # test forward
-    output = attn(inputs)
-    assert output.shape == inputs.shape
-    assert attn.relative_position_bias_table.shape == (
-        (2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
+    def test_forward(self):
+        attn = WindowMSA(embed_dims=96, window_size=(7, 7), num_heads=4)
+        inputs = torch.rand((16, 7 * 7, 96))
+        output = attn(inputs)
+        self.assertEqual(output.shape, inputs.shape)

-    # test relative_position_bias_table init
-    attn.init_weights()
-    assert abs(attn.relative_position_bias_table).sum() > 0
+        # test non-square window_size
+        attn = WindowMSA(embed_dims=96, window_size=(6, 7), num_heads=4)
+        inputs = torch.rand((16, 6 * 7, 96))
+        output = attn(inputs)
+        self.assertEqual(output.shape, inputs.shape)

-    # test non-square window_size
-    window_size = (6, 7)
-    attn = WindowMSA(
-        embed_dims=embed_dims, window_size=window_size, num_heads=num_heads)
-    inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
-                         window_size[0] * window_size[1], embed_dims))
-    output = attn(inputs)
-    assert output.shape == inputs.shape
+    def test_relative_pos_embed(self):
+        attn = WindowMSA(embed_dims=96, window_size=(7, 8), num_heads=4)
+        self.assertEqual(attn.relative_position_bias_table.shape,
+                         ((2 * 7 - 1) * (2 * 8 - 1), 4))
+        # test relative_position_index
+        expected_rel_pos_index = get_relative_position_index((7, 8))
+        self.assertTrue(
+            torch.allclose(attn.relative_position_index,
+                           expected_rel_pos_index))

-    # test relative_position_index
-    expected_rel_pos_index = get_relative_position_index(window_size)
-    assert (attn.relative_position_index == expected_rel_pos_index).all()
+        # test default init
+        self.assertTrue(
+            torch.allclose(attn.relative_position_bias_table,
+                           torch.tensor(0.)))
+        attn.init_weights()
+        self.assertFalse(
+            torch.allclose(attn.relative_position_bias_table,
+                           torch.tensor(0.)))

-    # test qkv_bias=True
-    attn = WindowMSA(
-        embed_dims=embed_dims,
-        window_size=window_size,
-        num_heads=num_heads,
-        qkv_bias=True)
-    assert attn.qkv.bias.shape == (embed_dims * 3, )
+    def test_qkv_bias(self):
+        # test qkv_bias=True
+        attn = WindowMSA(
+            embed_dims=96, window_size=(7, 7), num_heads=4, qkv_bias=True)
+        self.assertEqual(attn.qkv.bias.shape, (96 * 3, ))

-    # test qkv_bias=False
-    attn = WindowMSA(
-        embed_dims=embed_dims,
-        window_size=window_size,
-        num_heads=num_heads,
-        qkv_bias=False)
-    assert attn.qkv.bias is None
+        # test qkv_bias=False
+        attn = WindowMSA(
+            embed_dims=96, window_size=(7, 7), num_heads=4, qkv_bias=False)
+        self.assertIsNone(attn.qkv.bias)

-    # test default qk_scale
-    attn = WindowMSA(
-        embed_dims=embed_dims,
-        window_size=window_size,
-        num_heads=num_heads,
-        qk_scale=None)
-    head_dims = embed_dims // num_heads
-    assert np.isclose(attn.scale, head_dims**-0.5)
+    def test_qk_scale(self):
+        # test default qk_scale
+        attn = WindowMSA(
+            embed_dims=96, window_size=(7, 7), num_heads=4, qk_scale=None)
+        head_dims = 96 // 4
+        self.assertAlmostEqual(attn.scale, head_dims**-0.5)

-    # test specified qk_scale
-    attn = WindowMSA(
-        embed_dims=embed_dims,
-        window_size=window_size,
-        num_heads=num_heads,
-        qk_scale=0.3)
-    assert attn.scale == 0.3
+        # test specified qk_scale
+        attn = WindowMSA(
+            embed_dims=96, window_size=(7, 7), num_heads=4, qk_scale=0.3)
+        self.assertEqual(attn.scale, 0.3)

-    # test attn_drop
-    attn = WindowMSA(
-        embed_dims=embed_dims,
-        window_size=window_size,
-        num_heads=num_heads,
-        attn_drop=1.0)
-    inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
-                         window_size[0] * window_size[1], embed_dims))
-    # drop all attn output, output shuold be equal to proj.bias
-    assert torch.allclose(attn(inputs), attn.proj.bias)
+    def test_attn_drop(self):
+        inputs = torch.rand(16, 7 * 7, 96)
+        attn = WindowMSA(
+            embed_dims=96, window_size=(7, 7), num_heads=4, attn_drop=1.0)
+        # drop all attn output, so the output should be equal to proj.bias
+        self.assertTrue(torch.allclose(attn(inputs), attn.proj.bias))

-    # test prob_drop
-    attn = WindowMSA(
-        embed_dims=embed_dims,
-        window_size=window_size,
-        num_heads=num_heads,
-        proj_drop=1.0)
-    assert (attn(inputs) == 0).all()
+    def test_proj_drop(self):
+        inputs = torch.rand(16, 7 * 7, 96)
+        attn = WindowMSA(
+            embed_dims=96, window_size=(7, 7), num_heads=4, proj_drop=1.0)
+        self.assertTrue(torch.allclose(attn(inputs), torch.tensor(0.)))
+
+    def test_mask(self):
+        inputs = torch.rand(16, 7 * 7, 96)
+        attn = WindowMSA(embed_dims=96, window_size=(7, 7), num_heads=4)
+        mask = torch.zeros((4, 49, 49))
+        # Mask the attention from and to the first token
+        mask[:, 0, :] = -100
+        mask[:, :, 0] = -100
+        outs = attn(inputs, mask=mask)
+        inputs[:, 0, :].normal_()
+        outs_with_mask = attn(inputs, mask=mask)
+        torch.testing.assert_allclose(outs[:, 1:, :], outs_with_mask[:, 1:, :])


-def test_shift_window_msa():
-    batch_size = 1
-    embed_dims = 96
-    input_resolution = (14, 14)
-    num_heads = 4
-    window_size = 7
+class TestShiftWindowMSA(TestCase):

-    # test forward
-    attn = ShiftWindowMSA(
-        embed_dims=embed_dims,
-        input_resolution=input_resolution,
-        num_heads=num_heads,
-        window_size=window_size)
-    inputs = torch.rand(
-        (batch_size, input_resolution[0] * input_resolution[1], embed_dims))
-    output = attn(inputs)
-    assert output.shape == (inputs.shape)
-    assert attn.w_msa.relative_position_bias_table.shape == ((2 * window_size -
-                                                              1)**2, num_heads)
+    def test_forward(self):
+        inputs = torch.rand((1, 14 * 14, 96))
+        attn = ShiftWindowMSA(embed_dims=96, window_size=7, num_heads=4)
+        output = attn(inputs, (14, 14))
+        self.assertEqual(output.shape, inputs.shape)
+        self.assertEqual(attn.w_msa.relative_position_bias_table.shape,
+                         ((2 * 7 - 1)**2, 4))

-    # test forward with shift_size
-    attn = ShiftWindowMSA(
-        embed_dims=embed_dims,
-        input_resolution=input_resolution,
-        num_heads=num_heads,
-        window_size=window_size,
-        shift_size=1)
-    output = attn(inputs)
-    assert output.shape == (inputs.shape)
+        # test forward with shift_size
+        attn = ShiftWindowMSA(
+            embed_dims=96, window_size=7, num_heads=4, shift_size=3)
+        output = attn(inputs, (14, 14))
+        self.assertEqual(output.shape, inputs.shape)

-    # test relative_position_bias_table init
-    attn.init_weights()
-    assert abs(attn.w_msa.relative_position_bias_table).sum() > 0
+        # test irregular input shape
+        input_resolution = (19, 18)
+        attn = ShiftWindowMSA(embed_dims=96, num_heads=4, window_size=7)
+        inputs = torch.rand((1, 19 * 18, 96))
+        output = attn(inputs, input_resolution)
+        self.assertEqual(output.shape, inputs.shape)

-    # test dropout_layer
-    attn = ShiftWindowMSA(
-        embed_dims=embed_dims,
-        input_resolution=input_resolution,
-        num_heads=num_heads,
-        window_size=window_size,
-        dropout_layer=dict(type='DropPath', drop_prob=0.5))
-    torch.manual_seed(0)
-    output = attn(inputs)
-    assert (output == 0).all()
+        # test wrong input_resolution
+        input_resolution = (14, 14)
+        attn = ShiftWindowMSA(embed_dims=96, num_heads=4, window_size=7)
+        inputs = torch.rand((1, 14 * 14, 96))
+        with pytest.raises(AssertionError):
+            attn(inputs, (14, 15))

-    # test auto_pad
-    input_resolution = (19, 18)
-    attn = ShiftWindowMSA(
-        embed_dims=embed_dims,
-        input_resolution=input_resolution,
-        num_heads=num_heads,
-        window_size=window_size,
-        auto_pad=True)
-    assert attn.pad_r == 3
-    assert attn.pad_b == 2
+    def test_pad_small_map(self):
+        # test pad_small_map=True
+        inputs = torch.rand((1, 6 * 7, 96))
+        attn = ShiftWindowMSA(
+            embed_dims=96,
+            window_size=7,
+            num_heads=4,
+            shift_size=3,
+            pad_small_map=True)
+        attn.get_attn_mask = MagicMock(wraps=attn.get_attn_mask)
+        output = attn(inputs, (6, 7))
+        self.assertEqual(output.shape, inputs.shape)
+        attn.get_attn_mask.assert_called_once_with((7, 7),
+                                                   window_size=7,
+                                                   shift_size=3,
+                                                   device=ANY)

-    # test small input_resolution
-    input_resolution = (5, 6)
-    attn = ShiftWindowMSA(
-        embed_dims=embed_dims,
-        input_resolution=input_resolution,
-        num_heads=num_heads,
-        window_size=window_size,
-        shift_size=3,
-        auto_pad=True)
-    assert attn.window_size == 5
-    assert attn.shift_size == 0
+        # test pad_small_map=False
+        inputs = torch.rand((1, 6 * 7, 96))
+        attn = ShiftWindowMSA(
+            embed_dims=96,
+            window_size=7,
+            num_heads=4,
+            shift_size=3,
+            pad_small_map=False)
+        with self.assertRaisesRegex(AssertionError, r'the window size \(7\)'):
+            attn(inputs, (6, 7))
+
+        # test pad_small_map=False, and the input size equals the window size
+        inputs = torch.rand((1, 7 * 7, 96))
+        attn.get_attn_mask = MagicMock(wraps=attn.get_attn_mask)
+        output = attn(inputs, (7, 7))
+        self.assertEqual(output.shape, inputs.shape)
+        attn.get_attn_mask.assert_called_once_with((7, 7),
+                                                   window_size=7,
+                                                   shift_size=0,
+                                                   device=ANY)
+
+    def test_drop_layer(self):
+        inputs = torch.rand((1, 14 * 14, 96))
+        attn = ShiftWindowMSA(
+            embed_dims=96,
+            window_size=7,
+            num_heads=4,
+            dropout_layer=dict(type='Dropout', drop_prob=1.0))
+        attn.init_weights()
+        # drop all attn output, so the output should be all zeros
+        self.assertTrue(
+            torch.allclose(attn(inputs, (14, 14)), torch.tensor(0.)))
+
+    def test_deprecation(self):
+        # test deprecated arguments
+        with pytest.warns(DeprecationWarning):
+            ShiftWindowMSA(
+                embed_dims=96,
+                num_heads=4,
+                window_size=7,
+                input_resolution=(14, 14))
+
+        with pytest.warns(DeprecationWarning):
+            ShiftWindowMSA(
+                embed_dims=96, num_heads=4, window_size=7, auto_pad=True)
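For reference, a minimal usage sketch of the dynamic-shape behavior this patch enables (assuming mmcls with this change applied; the shapes follow the ceil-rounded 16-pixel patch grid asserted in the tests above):

    import torch
    from mmcls.models.backbones import VisionTransformer
    from mmcls.models.utils import resize_pos_embed

    # `img_size` is now only the most common input shape; other shapes work.
    model = VisionTransformer(arch='b', img_size=224, patch_size=16)
    model.init_weights()

    # A 256x309 input yields a ceil(256/16) x ceil(309/16) = 16x20 patch grid.
    imgs = torch.randn(1, 3, 256, 309)
    patch_token, cls_token = model(imgs)[-1]
    assert patch_token.shape == (1, 768, 16, 20)
    assert cls_token.shape == (1, 768)

    # The backbone resizes its position embedding on the fly with
    # `resize_pos_embed`, which can also be called directly:
    pos_embed = torch.randn(1, 14 * 14 + 1, 768)  # 14x14 grid + 1 cls_token
    resized = resize_pos_embed(
        pos_embed, src_shape=(14, 14), dst_shape=(16, 20),
        mode='bicubic', num_extra_tokens=1)
    assert resized.shape == (1, 16 * 20 + 1, 768)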