[Enhance] Support dynamic input shape for ViT-based algorithms. (#706)

* Move `resize_pos_embed` to `mmcls.models.utils`

* Refactor Vision Transformer

* Refactor DeiT

* Refactor MLP-Mixer

* Refactor Swin-Transformer

* Remove `indexing` arg

* Support dynamic inputs for t2t_vit

* Add copyright

* Fix bugs in swin transformer

* Add `pad_small_maps` option

* Update swin transformer

* Handle `attn_mask` in checkpoints of swin

* Improve by comments
Ma Zerun 2022-03-03 13:10:12 +08:00 committed by GitHub
parent 24ae53a4a0
commit c708770b42
17 changed files with 1483 additions and 947 deletions
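With this change, the ViT-family backbones accept inputs whose spatial size differs from `img_size`, because the position embedding is resized to the actual patch grid in `forward`. A minimal usage sketch, assuming an mmcls build that includes this commit (the default `arch='base'` and `out_indices=-1` used here follow the code in the diff below):

import torch
from mmcls.models.backbones import VisionTransformer

# `img_size` is only the most common input shape, not a hard constraint.
model = VisionTransformer(arch='base', img_size=224, patch_size=16)
model.eval()

for size in [(224, 224), (256, 256), (256, 320)]:
    imgs = torch.randn(1, 3, *size)
    with torch.no_grad():
        patch_token, cls_token = model(imgs)[-1]
    # The patch token map follows the input resolution, e.g.
    # (1, 768, 14, 14) for 224x224 and (1, 768, 16, 20) for 256x320.
    print(size, patch_token.shape, cls_token.shape)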


@ -15,21 +15,38 @@ class DistilledVisionTransformer(VisionTransformer):
distillation through attention <https://arxiv.org/abs/2012.12877>`_
Args:
arch (str | dict): Vision Transformer architecture
Default: 'b'
img_size (int | tuple): Input image size
patch_size (int | tuple): The patch size
arch (str | dict): Vision Transformer architecture. If a string is used,
choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
and 'deit-base'. If a dict is used, it should have the below keys:
- **embed_dims** (int): The dimensions of embedding.
- **num_layers** (int): The number of transformer encoder layers.
- **num_heads** (int): The number of heads in attention modules.
- **feedforward_channels** (int): The hidden dimensions in
feedforward modules.
Defaults to 'deit-base'.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 16.
in_channels (int): The num of input channels. Defaults to 3.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add an additional layer to normalize
final feature map. Defaults to True.
with_cls_token (bool): Whether to concatenate the class token with image
tokens as the transformer input. Defaults to True.
output_cls_token (bool): Whether to output the cls_token. If set True,
`with_cls_token` must be True. Defaults to True.
``with_cls_token`` must be True. Defaults to True.
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
@ -40,22 +57,31 @@ class DistilledVisionTransformer(VisionTransformer):
"""
num_extra_tokens = 2 # cls_token, dist_token
def __init__(self, *args, **kwargs):
super(DistilledVisionTransformer, self).__init__(*args, **kwargs)
def __init__(self, arch='deit-base', *args, **kwargs):
super(DistilledVisionTransformer, self).__init__(
arch=arch, *args, **kwargs)
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x)
patch_resolution = self.patch_embed.patches_resolution
x, patch_resolution = self.patch_embed(x)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
dist_token = self.dist_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, dist_token, x), dim=1)
x = x + self.pos_embed
x = x + self.resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=self.num_extra_tokens)
x = self.drop_after_pos(x)
if not self.with_cls_token:
# Remove class token for transformer encoder input
x = x[:, 2:]
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
@ -65,10 +91,16 @@ class DistilledVisionTransformer(VisionTransformer):
if i in self.out_indices:
B, _, C = x.shape
patch_token = x[:, 2:].reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = x[:, 0]
dist_token = x[:, 1]
if self.with_cls_token:
patch_token = x[:, 2:].reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = x[:, 0]
dist_token = x[:, 1]
else:
patch_token = x.reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = None
dist_token = None
if self.output_cls_token:
out = [patch_token, cls_token, dist_token]
else:


@ -3,11 +3,11 @@ from typing import Sequence
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmcv.runner.base_module import BaseModule, ModuleList
from ..builder import BACKBONES
from ..utils import PatchEmbed, to_2tuple
from ..utils import to_2tuple
from .base_backbone import BaseBackbone
@ -105,10 +105,20 @@ class MlpMixer(BaseBackbone):
<https://arxiv.org/pdf/2105.01601.pdf>`_
Args:
arch (str | dict): MLP Mixer architecture
Defaults to 'b'.
img_size (int | tuple): Input image size.
patch_size (int | tuple): The patch size.
arch (str | dict): MLP Mixer architecture. If a string is used, choose from
'small', 'base' and 'large'. If a dict is used, it should have the below
keys:
- **embed_dims** (int): The dimensions of embedding.
- **num_layers** (int): The number of MLP blocks.
- **tokens_mlp_dims** (int): The hidden dimensions for tokens FFNs.
- **channels_mlp_dims** (int): The hidden dimensions for
channels FFNs.
Defaults to 'base'.
img_size (int | tuple): The input image shape. Defaults to 224.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 16.
out_indices (Sequence | int): Output from which layer.
Defaults to -1, means the last layer.
drop_rate (float): Probability of an element to be zeroed.
@ -149,7 +159,7 @@ class MlpMixer(BaseBackbone):
}
def __init__(self,
arch='b',
arch='base',
img_size=224,
patch_size=16,
out_indices=-1,
@ -184,14 +194,16 @@ class MlpMixer(BaseBackbone):
self.img_size = to_2tuple(img_size)
_patch_cfg = dict(
img_size=img_size,
input_size=img_size,
embed_dims=self.embed_dims,
conv_cfg=dict(
type='Conv2d', kernel_size=patch_size, stride=patch_size),
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
)
_patch_cfg.update(patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
num_patches = self.patch_embed.num_patches
self.patch_resolution = self.patch_embed.init_out_size
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
if isinstance(out_indices, int):
out_indices = [out_indices]
@ -232,7 +244,10 @@ class MlpMixer(BaseBackbone):
return getattr(self, self.norm1_name)
def forward(self, x):
x = self.patch_embed(x)
assert x.shape[2:] == self.img_size, \
"The MLP-Mixer doesn't support dynamic input shape. " \
f'Please input images with shape {self.img_size}'
x, _ = self.patch_embed(x)
outs = []
for i, layer in enumerate(self.layers):


@ -2,17 +2,18 @@
from copy import deepcopy
from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed, PatchMerging
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from mmcv.utils.parrots_wrapper import _BatchNorm
from ..builder import BACKBONES
from ..utils import PatchEmbed, PatchMerging, ShiftWindowMSA
from ..utils import ShiftWindowMSA, resize_pos_embed, to_2tuple
from .base_backbone import BaseBackbone
@ -21,45 +22,41 @@ class SwinBlock(BaseModule):
Args:
embed_dims (int): Number of input channels.
input_resolution (Tuple[int, int]): The resolution of the input feature
map.
num_heads (int): Number of attention heads.
window_size (int, optional): The height and width of the window.
Defaults to 7.
shift (bool, optional): Shift the attention window or not.
window_size (int): The height and width of the window. Defaults to 7.
shift (bool): Shift the attention window or not. Defaults to False.
ffn_ratio (float): The expansion ratio of feedforward network hidden
layer channels. Defaults to 4.
drop_path (float): The drop path rate after attention and ffn.
Defaults to 0.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is commonly used in detection and segmentation. If False,
avoid shifting the window and shrink the window size to the size of
the feature map, which is commonly used in classification.
Defaults to False.
ffn_ratio (float, optional): The expansion ratio of feedforward network
hidden layer channels. Defaults to 4.
drop_path (float, optional): The drop path rate after attention and
ffn. Defaults to 0.
attn_cfgs (dict, optional): The extra config of Shift Window-MSA.
attn_cfgs (dict): The extra config of Shift Window-MSA.
Defaults to empty dict.
ffn_cfgs (dict, optional): The extra config of FFN.
Defaults to empty dict.
norm_cfg (dict, optional): The config of norm layers.
Defaults to dict(type='LN').
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Defaults to False.
auto_pad (bool, optional): Auto pad the feature map to be divisible by
window_size, Defaults to False.
ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
norm_cfg (dict): The config of norm layers.
Defaults to ``dict(type='LN')``.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Default: None.
Defaults to None.
"""
def __init__(self,
embed_dims,
input_resolution,
num_heads,
window_size=7,
shift=False,
ffn_ratio=4.,
drop_path=0.,
pad_small_map=False,
attn_cfgs=dict(),
ffn_cfgs=dict(),
norm_cfg=dict(type='LN'),
with_cp=False,
auto_pad=False,
init_cfg=None):
super(SwinBlock, self).__init__(init_cfg)
@ -67,12 +64,11 @@ class SwinBlock(BaseModule):
_attn_cfgs = {
'embed_dims': embed_dims,
'input_resolution': input_resolution,
'num_heads': num_heads,
'shift_size': window_size // 2 if shift else 0,
'window_size': window_size,
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
'auto_pad': auto_pad,
'pad_small_map': pad_small_map,
**attn_cfgs
}
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
@ -90,12 +86,12 @@ class SwinBlock(BaseModule):
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
self.ffn = FFN(**_ffn_cfgs)
def forward(self, x):
def forward(self, x, hw_shape):
def _inner_forward(x):
identity = x
x = self.norm1(x)
x = self.attn(x)
x = self.attn(x, hw_shape)
x = x + identity
identity = x
@ -117,38 +113,39 @@ class SwinBlockSequence(BaseModule):
Args:
embed_dims (int): Number of input channels.
input_resolution (Tuple[int, int]): The resolution of the input feature
map.
depth (int): Number of successive swin transformer blocks.
num_heads (int): Number of attention heads.
downsample (bool, optional): Downsample the output of blocks by patch
merging. Defaults to False.
downsample_cfg (dict, optional): The extra config of the patch merging
layer. Defaults to empty dict.
drop_paths (Sequence[float] | float, optional): The drop path rate in
each block. Defaults to 0.
block_cfgs (Sequence[dict] | dict, optional): The extra config of each
block. Defaults to empty dicts.
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
window_size (int): The height and width of the window. Defaults to 7.
downsample (bool): Downsample the output of blocks by patch merging.
Defaults to False.
downsample_cfg (dict): The extra config of the patch merging layer.
Defaults to empty dict.
drop_paths (Sequence[float] | float): The drop path rate in each block.
Defaults to 0.
block_cfgs (Sequence[dict] | dict): The extra config of each block.
Defaults to empty dicts.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is commonly used in detection and segmentation. If False,
avoid shifting the window and shrink the window size to the size of
the feature map, which is commonly used in classification.
Defaults to False.
auto_pad (bool, optional): Auto pad the feature map to be divisible by
window_size, Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Default: None.
Defaults to None.
"""
def __init__(self,
embed_dims,
input_resolution,
depth,
num_heads,
window_size=7,
downsample=False,
downsample_cfg=dict(),
drop_paths=0.,
block_cfgs=dict(),
with_cp=False,
auto_pad=False,
pad_small_map=False,
init_cfg=None):
super().__init__(init_cfg)
@ -159,17 +156,16 @@ class SwinBlockSequence(BaseModule):
block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]
self.embed_dims = embed_dims
self.input_resolution = input_resolution
self.blocks = ModuleList()
for i in range(depth):
_block_cfg = {
'embed_dims': embed_dims,
'input_resolution': input_resolution,
'num_heads': num_heads,
'window_size': window_size,
'shift': False if i % 2 == 0 else True,
'drop_path': drop_paths[i],
'with_cp': with_cp,
'auto_pad': auto_pad,
'pad_small_map': pad_small_map,
**block_cfgs[i]
}
block = SwinBlock(**_block_cfg)
@ -177,9 +173,8 @@ class SwinBlockSequence(BaseModule):
if downsample:
_downsample_cfg = {
'input_resolution': input_resolution,
'in_channels': embed_dims,
'expansion_ratio': 2,
'out_channels': 2 * embed_dims,
'norm_cfg': dict(type='LN'),
**downsample_cfg
}
@ -187,20 +182,15 @@ class SwinBlockSequence(BaseModule):
else:
self.downsample = None
def forward(self, x):
def forward(self, x, in_shape):
for block in self.blocks:
x = block(x)
x = block(x, in_shape)
if self.downsample:
x = self.downsample(x)
return x
@property
def out_resolution(self):
if self.downsample:
return self.downsample.output_resolution
x, out_shape = self.downsample(x, in_shape)
else:
return self.input_resolution
out_shape = in_shape
return x, out_shape
@property
def out_channels(self):
@ -212,7 +202,8 @@ class SwinBlockSequence(BaseModule):
@BACKBONES.register_module()
class SwinTransformer(BaseBackbone):
""" Swin Transformer
"""Swin Transformer.
A PyTorch implement of : `Swin Transformer:
Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/abs/2103.14030>`_
@ -221,34 +212,47 @@ class SwinTransformer(BaseBackbone):
https://github.com/microsoft/Swin-Transformer
Args:
arch (str | dict): Swin Transformer architecture
Defaults to 'T'.
img_size (int | tuple): The size of input image.
Defaults to 224.
in_channels (int): The num of input channels.
Defaults to 3.
drop_rate (float): Dropout rate after embedding.
Defaults to 0.
drop_path_rate (float): Stochastic depth rate.
Defaults to 0.1.
arch (str | dict): Swin Transformer architecture. If a string is used, choose
from 'tiny', 'small', 'base' and 'large'. If a dict is used, it should
have the below keys:
- **embed_dims** (int): The dimensions of embedding.
- **depths** (List[int]): The number of blocks in each stage.
- **num_heads** (List[int]): The number of heads in attention
modules of each stage.
Defaults to 'tiny'.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 4.
in_channels (int): The num of input channels. Defaults to 3.
window_size (int): The height and width of the window. Defaults to 7.
drop_rate (float): Dropout rate after embedding. Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
use_abs_pos_embed (bool): If True, add absolute position embedding to
the patch embedding. Defaults to False.
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Defaults to False.
interpolate_mode (str): Select the interpolate mode for absolute
position embedding vector resize. Defaults to "bicubic".
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
auto_pad (bool): If True, auto pad feature map to fit window_size.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is commonly used in detection and segmentation. If False,
avoid shifting the window and shrink the window size to the size of
the feature map, which is commonly used in classification.
Defaults to False.
norm_cfg (dict, optional): Config dict for normalization layer at end
of backone. Defaults to dict(type='LN')
stage_cfgs (Sequence | dict, optional): Extra config dict for each
stage. Defaults to empty dict.
patch_cfg (dict, optional): Extra config dict for patch embedding.
Defaults to empty dict.
norm_cfg (dict): Config dict for normalization layer for all output
features. Defaults to ``dict(type='LN')``
stage_cfgs (Sequence[dict] | dict): Extra config dict for each
stage. Defaults to an empty dict.
patch_cfg (dict): Extra config dict for patch embedding.
Defaults to an empty dict.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
@ -258,8 +262,7 @@ class SwinTransformer(BaseBackbone):
>>> extra_config = dict(
>>> arch='tiny',
>>> stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
>>> 'expansion_ratio': 3}),
>>> auto_pad=True)
>>> 'expansion_ratio': 3}))
>>> self = SwinTransformer(**extra_config)
>>> inputs = torch.rand(1, 3, 224, 224)
>>> output = self.forward(inputs)
@ -285,25 +288,29 @@ class SwinTransformer(BaseBackbone):
'num_heads': [6, 12, 24, 48]}),
} # yapf: disable
_version = 2
_version = 3
num_extra_tokens = 0
def __init__(self,
arch='T',
arch='tiny',
img_size=224,
patch_size=4,
in_channels=3,
window_size=7,
drop_rate=0.,
drop_path_rate=0.1,
out_indices=(3, ),
use_abs_pos_embed=False,
auto_pad=False,
interpolate_mode='bicubic',
with_cp=False,
frozen_stages=-1,
norm_eval=False,
pad_small_map=False,
norm_cfg=dict(type='LN'),
stage_cfgs=dict(),
patch_cfg=dict(),
init_cfg=None):
super(SwinTransformer, self).__init__(init_cfg)
super(SwinTransformer, self).__init__(init_cfg=init_cfg)
if isinstance(arch, str):
arch = arch.lower()
@ -311,7 +318,7 @@ class SwinTransformer(BaseBackbone):
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {'embed_dims', 'depths', 'num_head'}
essential_keys = {'embed_dims', 'depths', 'num_heads'}
assert isinstance(arch, dict) and set(arch) == essential_keys, \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
@ -322,26 +329,28 @@ class SwinTransformer(BaseBackbone):
self.num_layers = len(self.depths)
self.out_indices = out_indices
self.use_abs_pos_embed = use_abs_pos_embed
self.auto_pad = auto_pad
self.interpolate_mode = interpolate_mode
self.frozen_stages = frozen_stages
self.num_extra_tokens = 0
_patch_cfg = {
'img_size': img_size,
'in_channels': in_channels,
'embed_dims': self.embed_dims,
'conv_cfg': dict(type='Conv2d', kernel_size=4, stride=4),
'norm_cfg': dict(type='LN'),
**patch_cfg
}
_patch_cfg = dict(
in_channels=in_channels,
input_size=img_size,
embed_dims=self.embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
norm_cfg=dict(type='LN'),
)
_patch_cfg.update(patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
self.patch_resolution = self.patch_embed.init_out_size
if self.use_abs_pos_embed:
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
self.absolute_pos_embed = nn.Parameter(
torch.zeros(1, num_patches, self.embed_dims))
self._register_load_state_dict_pre_hook(
self._prepare_abs_pos_embed)
self.drop_after_pos = nn.Dropout(p=drop_rate)
self.norm_eval = norm_eval
@ -354,7 +363,6 @@ class SwinTransformer(BaseBackbone):
self.stages = ModuleList()
embed_dims = [self.embed_dims]
input_resolution = patches_resolution
for i, (depth,
num_heads) in enumerate(zip(self.depths, self.num_heads)):
if isinstance(stage_cfgs, Sequence):
@ -366,11 +374,11 @@ class SwinTransformer(BaseBackbone):
'embed_dims': embed_dims[-1],
'depth': depth,
'num_heads': num_heads,
'window_size': window_size,
'downsample': downsample,
'input_resolution': input_resolution,
'drop_paths': dpr[:depth],
'with_cp': with_cp,
'auto_pad': auto_pad,
'pad_small_map': pad_small_map,
**stage_cfg
}
@ -379,7 +387,6 @@ class SwinTransformer(BaseBackbone):
dpr = dpr[depth:]
embed_dims.append(stage.out_channels)
input_resolution = stage.out_resolution
for i in out_indices:
if norm_cfg is not None:
@ -401,18 +408,20 @@ class SwinTransformer(BaseBackbone):
trunc_normal_(self.absolute_pos_embed, std=0.02)
def forward(self, x):
x = self.patch_embed(x)
x, hw_shape = self.patch_embed(x)
if self.use_abs_pos_embed:
x = x + self.absolute_pos_embed
x = x + resize_pos_embed(
self.absolute_pos_embed, self.patch_resolution, hw_shape,
self.interpolate_mode, self.num_extra_tokens)
x = self.drop_after_pos(x)
outs = []
for i, stage in enumerate(self.stages):
x = stage(x)
x, hw_shape = stage(x, hw_shape)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(x)
out = out.view(-1, *stage.out_resolution,
out = out.view(-1, *hw_shape,
stage.out_channels).permute(0, 3, 1,
2).contiguous()
outs.append(out)
@ -433,6 +442,12 @@ class SwinTransformer(BaseBackbone):
convert_key = k.replace('norm.', f'norm{final_stage_num}.')
state_dict[convert_key] = state_dict[k]
del state_dict[k]
if (version is None
or version < 3) and self.__class__ is SwinTransformer:
state_dict_keys = list(state_dict.keys())
for k in state_dict_keys:
if 'attn_mask' in k:
del state_dict[k]
super()._load_from_state_dict(state_dict, prefix, local_metadata,
*args, **kwargs)
@ -461,3 +476,26 @@ class SwinTransformer(BaseBackbone):
# trick: eval have effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'absolute_pos_embed'
if name not in state_dict.keys():
return
ckpt_pos_embed_shape = state_dict[name].shape
if self.absolute_pos_embed.shape != ckpt_pos_embed_shape:
from mmcls.utils import get_root_logger
logger = get_root_logger()
logger.info(
'Resize the absolute_pos_embed shape from '
f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.')
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
pos_embed_shape = self.patch_embed.init_out_size
state_dict[name] = resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
self.num_extra_tokens)
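Since every stage now receives and returns the current `hw_shape`, the Swin backbone no longer needs a fixed `input_resolution`. A small sketch of feeding two resolutions to the default 'tiny' arch, assuming an mmcls build with this commit (the stage-3 output has 768 channels at 1/32 resolution):

import torch
from mmcls.models.backbones import SwinTransformer

model = SwinTransformer(arch='tiny', img_size=224)
model.eval()

for size in [(224, 224), (256, 256)]:
    imgs = torch.randn(1, 3, *size)
    with torch.no_grad():
        out = model(imgs)[-1]
    # e.g. (1, 768, 7, 7) for 224x224 and (1, 768, 8, 8) for 256x256.
    print(size, out.shape)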


@ -11,7 +11,7 @@ from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from ..builder import BACKBONES
from ..utils import MultiheadAttention
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
from .base_backbone import BaseBackbone
@ -173,26 +173,44 @@ class T2TModule(BaseModule):
raise NotImplementedError("Performer hasn't been implemented.")
# there are 3 soft split, stride are 4,2,2 separately
self.num_patches = (img_size // (4 * 2 * 2))**2
out_side = img_size // (4 * 2 * 2)
self.init_out_size = [out_side, out_side]
self.num_patches = out_side**2
@staticmethod
def _get_unfold_size(unfold: nn.Unfold, input_size):
h, w = input_size
kernel_size = to_2tuple(unfold.kernel_size)
stride = to_2tuple(unfold.stride)
padding = to_2tuple(unfold.padding)
dilation = to_2tuple(unfold.dilation)
h_out = (h + 2 * padding[0] - dilation[0] *
(kernel_size[0] - 1) - 1) // stride[0] + 1
w_out = (w + 2 * padding[1] - dilation[1] *
(kernel_size[1] - 1) - 1) // stride[1] + 1
return (h_out, w_out)
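`_get_unfold_size` applies the standard `nn.Unfold` output-length formula, `(H + 2p - d(k - 1) - 1) // s + 1`, so the token grid can be tracked for any input size. A quick plain-PyTorch check of that arithmetic with hypothetical sizes:

import torch
import torch.nn as nn

h, w = 100, 73
unfold = nn.Unfold(kernel_size=7, stride=4, padding=2, dilation=1)

# Same arithmetic as in `_get_unfold_size`.
h_out = (h + 2 * 2 - 1 * (7 - 1) - 1) // 4 + 1  # 25
w_out = (w + 2 * 2 - 1 * (7 - 1) - 1) // 4 + 1  # 18

x = torch.randn(1, 3, h, w)
# nn.Unfold returns (B, C * k * k, L) with L = h_out * w_out.
assert unfold(x).shape[-1] == h_out * w_out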
def forward(self, x):
# step0: soft split
hw_shape = self._get_unfold_size(self.soft_split0, x.shape[2:])
x = self.soft_split0(x).transpose(1, 2)
for step in [1, 2]:
# re-structurization/reconstruction
attn = getattr(self, f'attention{step}')
x = attn(x).transpose(1, 2)
B, C, new_HW = x.shape
x = x.reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))
B, C, _ = x.shape
x = x.reshape(B, C, hw_shape[0], hw_shape[1])
# soft split
soft_split = getattr(self, f'soft_split{step}')
hw_shape = self._get_unfold_size(soft_split, hw_shape)
x = soft_split(x).transpose(1, 2)
# final tokens
x = self.project(x)
return x
return x, hw_shape
def get_sinusoid_encoding(n_position, embed_dims):
@ -231,43 +249,52 @@ class T2T_ViT(BaseBackbone):
Transformers from Scratch on ImageNet <https://arxiv.org/abs/2101.11986>`_
Args:
img_size (int): Input image size.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
in_channels (int): Number of input channels.
embed_dims (int): Embedding dimension.
t2t_cfg (dict): Extra config of Tokens-to-Token module.
Defaults to an empty dict.
drop_rate (float): Dropout rate after position embedding.
Defaults to 0.
num_layers (int): Num of transformer layers in encoder.
Defaults to 14.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
drop_rate (float): Dropout rate after position embedding.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
norm_cfg (dict): Config dict for normalization layer. Defaults to
``dict(type='LN')``.
final_norm (bool): Whether to add an additional layer to normalize
final feature map. Defaults to True.
output_cls_token (bool): Whether output the cls_token.
Defaults to True.
with_cls_token (bool): Whether to concatenate the class token with image
tokens as the transformer input. Defaults to True.
output_cls_token (bool): Whether to output the cls_token. If set True,
``with_cls_token`` must be True. Defaults to True.
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
t2t_cfg (dict): Extra config of Tokens-to-Token module.
Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
num_extra_tokens = 1 # cls_token
def __init__(self,
img_size=224,
in_channels=3,
embed_dims=384,
t2t_cfg=dict(),
drop_rate=0.,
num_layers=14,
out_indices=-1,
layer_cfgs=dict(),
drop_rate=0.,
drop_path_rate=0.,
norm_cfg=dict(type='LN'),
final_norm=True,
with_cls_token=True,
output_cls_token=True,
interpolate_mode='bicubic',
t2t_cfg=dict(),
layer_cfgs=dict(),
init_cfg=None):
super(T2T_ViT, self).__init__(init_cfg)
@ -277,30 +304,41 @@ class T2T_ViT(BaseBackbone):
in_channels=in_channels,
embed_dims=embed_dims,
**t2t_cfg)
num_patches = self.tokens_to_token.num_patches
self.patch_resolution = self.tokens_to_token.init_out_size
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
# Class token
# Set cls token
if output_cls_token:
assert with_cls_token is True, f'with_cls_token must be True if ' \
f'set output_cls_token to True, but got {with_cls_token}'
self.with_cls_token = with_cls_token
self.output_cls_token = output_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
self.num_extra_tokens = 1
# Position Embedding
sinusoid_table = get_sinusoid_encoding(num_patches + 1, embed_dims)
# Set position embedding
self.interpolate_mode = interpolate_mode
sinusoid_table = get_sinusoid_encoding(
num_patches + self.num_extra_tokens, embed_dims)
self.register_buffer('pos_embed', sinusoid_table)
self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
self.drop_after_pos = nn.Dropout(p=drop_rate)
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must by a sequence or int, ' \
f'"out_indices" must be a sequence or int, ' \
f'get {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = num_layers + index
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
assert 0 <= out_indices[i] <= num_layers, \
f'Invalid out_indices {index}'
self.out_indices = out_indices
# stochastic depth decay rule
dpr = [x for x in np.linspace(0, drop_path_rate, num_layers)]
self.encoder = ModuleList()
for i in range(num_layers):
if isinstance(layer_cfgs, Sequence):
@ -336,17 +374,49 @@ class T2T_ViT(BaseBackbone):
trunc_normal_(self.cls_token, std=.02)
def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'pos_embed'
if name not in state_dict.keys():
return
ckpt_pos_embed_shape = state_dict[name].shape
if self.pos_embed.shape != ckpt_pos_embed_shape:
from mmcls.utils import get_root_logger
logger = get_root_logger()
logger.info(
f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
f'to {self.pos_embed.shape}.')
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
pos_embed_shape = self.tokens_to_token.init_out_size
state_dict[name] = resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
self.num_extra_tokens)
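With `_prepare_pos_embed` resizing the sinusoid table on load and `forward` resizing it per batch, `T2T_ViT` also accepts varying input sizes. A minimal sketch with the default config, assuming `T2T_ViT` is exported from `mmcls.models.backbones` in a build containing this commit:

import torch
from mmcls.models.backbones import T2T_ViT

model = T2T_ViT(img_size=224)
model.eval()

for size in [(224, 224), (256, 256)]:
    imgs = torch.randn(1, 3, *size)
    with torch.no_grad():
        patch_token, cls_token = model(imgs)[-1]
    # The three soft splits give a 1/16 token grid,
    # e.g. (1, 384, 14, 14) for 224x224 and (1, 384, 16, 16) for 256x256.
    print(size, patch_token.shape, cls_token.shape)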
def forward(self, x):
B = x.shape[0]
x = self.tokens_to_token(x)
num_patches = self.tokens_to_token.num_patches
patch_resolution = [int(np.sqrt(num_patches))] * 2
x, patch_resolution = self.tokens_to_token(x)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
x = x + resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=self.num_extra_tokens)
x = self.drop_after_pos(x)
if not self.with_cls_token:
# Remove class token for transformer encoder input
x = x[:, 1:]
outs = []
for i, layer in enumerate(self.encoder):
x = layer(x)
@ -356,9 +426,14 @@ class T2T_ViT(BaseBackbone):
if i in self.out_indices:
B, _, C = x.shape
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = x[:, 0]
if self.with_cls_token:
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = x[:, 0]
else:
patch_token = x.reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = None
if self.output_cls_token:
out = [patch_token, cls_token]
else:


@ -4,15 +4,14 @@ from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from mmcls.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import MultiheadAttention, PatchEmbed, to_2tuple
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
from .base_backbone import BaseBackbone
@ -108,21 +107,38 @@ class VisionTransformer(BaseBackbone):
for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_
Args:
arch (str | dict): Vision Transformer architecture
Default: 'b'
img_size (int | tuple): Input image size
patch_size (int | tuple): The patch size
arch (str | dict): Vision Transformer architecture. If a string is used,
choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
and 'deit-base'. If a dict is used, it should have the below keys:
- **embed_dims** (int): The dimensions of embedding.
- **num_layers** (int): The number of transformer encoder layers.
- **num_heads** (int): The number of heads in attention modules.
- **feedforward_channels** (int): The hidden dimensions in
feedforward modules.
Defaults to 'base'.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 16.
in_channels (int): The num of input channels. Defaults to 3.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add an additional layer to normalize
final feature map. Defaults to True.
with_cls_token (bool): Whether to concatenate the class token with image
tokens as the transformer input. Defaults to True.
output_cls_token (bool): Whether to output the cls_token. If set True,
`with_cls_token` must be True. Defaults to True.
``with_cls_token`` must be True. Defaults to True.
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
@ -138,7 +154,6 @@ class VisionTransformer(BaseBackbone):
'num_layers': 8,
'num_heads': 8,
'feedforward_channels': 768 * 3,
'qkv_bias': False
}),
**dict.fromkeys(
['b', 'base'], {
@ -180,14 +195,17 @@ class VisionTransformer(BaseBackbone):
num_extra_tokens = 1 # cls_token
def __init__(self,
arch='b',
arch='base',
img_size=224,
patch_size=16,
in_channels=3,
out_indices=-1,
drop_rate=0.,
drop_path_rate=0.,
qkv_bias=True,
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=True,
with_cls_token=True,
output_cls_token=True,
interpolate_mode='bicubic',
patch_cfg=dict(),
@ -214,16 +232,23 @@ class VisionTransformer(BaseBackbone):
# Set patch embedding
_patch_cfg = dict(
img_size=img_size,
in_channels=in_channels,
input_size=img_size,
embed_dims=self.embed_dims,
conv_cfg=dict(
type='Conv2d', kernel_size=patch_size, stride=patch_size),
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
)
_patch_cfg.update(patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
num_patches = self.patch_embed.num_patches
self.patch_resolution = self.patch_embed.init_out_size
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
# Set cls token
if output_cls_token:
assert with_cls_token is True, f'with_cls_token must be True if ' \
f'set output_cls_token to True, but got {with_cls_token}'
self.with_cls_token = with_cls_token
self.output_cls_token = output_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
@ -232,6 +257,8 @@ class VisionTransformer(BaseBackbone):
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + self.num_extra_tokens,
self.embed_dims))
self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
self.drop_after_pos = nn.Dropout(p=drop_rate)
if isinstance(out_indices, int):
@ -242,11 +269,12 @@ class VisionTransformer(BaseBackbone):
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = self.num_layers + index
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
assert 0 <= out_indices[i] <= self.num_layers, \
f'Invalid out_indices {index}'
self.out_indices = out_indices
# stochastic depth decay rule
dpr = np.linspace(0, drop_path_rate, self.arch_settings['num_layers'])
dpr = np.linspace(0, drop_path_rate, self.num_layers)
self.layers = ModuleList()
if isinstance(layer_cfgs, dict):
@ -259,7 +287,7 @@ class VisionTransformer(BaseBackbone):
arch_settings['feedforward_channels'],
drop_rate=drop_rate,
drop_path_rate=dpr[i],
qkv_bias=self.arch_settings.get('qkv_bias', True),
qkv_bias=qkv_bias,
norm_cfg=norm_cfg)
_layer_cfg.update(layer_cfgs[i])
self.layers.append(TransformerEncoderLayer(**_layer_cfg))
@ -270,8 +298,6 @@ class VisionTransformer(BaseBackbone):
norm_cfg, self.embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self._register_load_state_dict_pre_hook(self._prepare_checkpoint_hook)
@property
def norm1(self):
return getattr(self, self.norm1_name)
@ -283,7 +309,7 @@ class VisionTransformer(BaseBackbone):
and self.init_cfg['type'] == 'Pretrained'):
trunc_normal_(self.pos_embed, std=0.02)
def _prepare_checkpoint_hook(self, state_dict, prefix, *args, **kwargs):
def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'pos_embed'
if name not in state_dict.keys():
return
@ -299,61 +325,38 @@ class VisionTransformer(BaseBackbone):
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
pos_embed_shape = self.patch_embed.patches_resolution
pos_embed_shape = self.patch_embed.init_out_size
state_dict[name] = self.resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
self.num_extra_tokens)
state_dict[name] = resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
self.num_extra_tokens)
@staticmethod
def resize_pos_embed(pos_embed,
src_shape,
dst_shape,
mode='bicubic',
num_extra_tokens=1):
"""Resize pos_embed weights.
Args:
pos_embed (torch.Tensor): Position embedding weights with shape
[1, L, C].
src_shape (tuple): The resolution of downsampled origin training
image.
dst_shape (tuple): The resolution of downsampled new training
image.
mode (str): Algorithm used for upsampling:
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
``'trilinear'``. Default: ``'bicubic'``
Return:
torch.Tensor: The resized pos_embed of shape [1, L_new, C]
"""
assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
_, L, C = pos_embed.shape
src_h, src_w = src_shape
assert L == src_h * src_w + num_extra_tokens
extra_tokens = pos_embed[:, :num_extra_tokens]
src_weight = pos_embed[:, num_extra_tokens:]
src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)
dst_weight = F.interpolate(
src_weight, size=dst_shape, align_corners=False, mode=mode)
dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
return torch.cat((extra_tokens, dst_weight), dim=1)
def resize_pos_embed(*args, **kwargs):
"""Interface for backward-compatibility."""
return resize_pos_embed(*args, **kwargs)
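The `_prepare_pos_embed` hook rewrites the `pos_embed` entry of a checkpoint before it is loaded, so weights trained at one resolution can be reused at another without manual conversion. A sketch of that flow, mirroring the new unit tests further down (assumes `mmcv.runner.save_checkpoint`/`load_checkpoint` and an mmcls build with this commit):

import os
import tempfile

from mmcv.runner import load_checkpoint, save_checkpoint
from mmcls.models.backbones import VisionTransformer

src = VisionTransformer(arch='base', img_size=224)
checkpoint = os.path.join(tempfile.gettempdir(), 'vit_224.pth')
save_checkpoint(src, checkpoint)

# The hook resizes pos_embed from the 14x14 grid to the 24x24 grid
# expected by the 384-pixel model, so strict loading succeeds.
dst = VisionTransformer(arch='base', img_size=384)
load_checkpoint(dst, checkpoint, strict=True)
print(src.pos_embed.shape, dst.pos_embed.shape)  # (1, 197, 768), (1, 577, 768)

os.remove(checkpoint)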
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x)
patch_resolution = self.patch_embed.patches_resolution
x, patch_resolution = self.patch_embed(x)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
x = x + resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=self.num_extra_tokens)
x = self.drop_after_pos(x)
if not self.with_cls_token:
# Remove class token for transformer encoder input
x = x[:, 1:]
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
@ -363,9 +366,14 @@ class VisionTransformer(BaseBackbone):
if i in self.out_indices:
B, _, C = x.shape
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = x[:, 0]
if self.with_cls_token:
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = x[:, 0]
else:
patch_token = x.reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = None
if self.output_cls_token:
out = [patch_token, cls_token]
else:


@ -2,7 +2,7 @@
from .attention import MultiheadAttention, ShiftWindowMSA
from .augment.augments import Augments
from .channel_shuffle import channel_shuffle
from .embed import HybridEmbed, PatchEmbed, PatchMerging
from .embed import HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed
from .helpers import is_tracing, to_2tuple, to_3tuple, to_4tuple, to_ntuple
from .inverted_residual import InvertedResidual
from .make_divisible import make_divisible
@ -13,5 +13,5 @@ __all__ = [
'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer',
'to_ntuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'PatchEmbed',
'PatchMerging', 'HybridEmbed', 'Augments', 'ShiftWindowMSA', 'is_tracing',
'MultiheadAttention', 'ConditionalPositionEncoding'
'MultiheadAttention', 'ConditionalPositionEncoding', 'resize_pos_embed'
]


@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -126,14 +128,12 @@ class ShiftWindowMSA(BaseModule):
Args:
embed_dims (int): Number of input channels.
input_resolution (Tuple[int, int]): The resolution of the input feature
map.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window.
shift_size (int, optional): The shift step of each window towards
right-bottom. If zero, act as regular window-msa. Defaults to 0.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: True
Defaults to True
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Defaults to None.
attn_drop (float, optional): Dropout ratio of attention weight.
@ -141,15 +141,17 @@ class ShiftWindowMSA(BaseModule):
proj_drop (float, optional): Dropout ratio of output. Defaults to 0.
dropout_layer (dict, optional): The dropout_layer used before output.
Defaults to dict(type='DropPath', drop_prob=0.).
auto_pad (bool, optional): Auto pad the feature map to be divisible by
window_size, Defaults to False.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is commonly used in detection and segmentation. If False,
avoid shifting the window and shrink the window size to the size of
the feature map, which is commonly used in classification.
Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Default: None.
Defaults to None.
"""
def __init__(self,
embed_dims,
input_resolution,
num_heads,
window_size,
shift_size=0,
@ -158,53 +160,134 @@ class ShiftWindowMSA(BaseModule):
attn_drop=0,
proj_drop=0,
dropout_layer=dict(type='DropPath', drop_prob=0.),
auto_pad=False,
pad_small_map=False,
input_resolution=None,
auto_pad=None,
init_cfg=None):
super().__init__(init_cfg)
self.embed_dims = embed_dims
self.input_resolution = input_resolution
if input_resolution is not None or auto_pad is not None:
warnings.warn(
'The ShiftWindowMSA in new version has supported auto padding '
'and dynamic input shape in all condition. And the argument '
'`auto_pad` and `input_resolution` have been deprecated.',
DeprecationWarning)
self.shift_size = shift_size
self.window_size = window_size
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, don't partition
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size
self.w_msa = WindowMSA(embed_dims, to_2tuple(self.window_size),
num_heads, qkv_bias, qk_scale, attn_drop,
proj_drop)
self.w_msa = WindowMSA(
embed_dims=embed_dims,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=proj_drop,
)
self.drop = build_dropout(dropout_layer)
self.pad_small_map = pad_small_map
H, W = self.input_resolution
# Handle auto padding
self.auto_pad = auto_pad
if self.auto_pad:
self.pad_r = (self.window_size -
W % self.window_size) % self.window_size
self.pad_b = (self.window_size -
H % self.window_size) % self.window_size
self.H_pad = H + self.pad_b
self.W_pad = W + self.pad_r
else:
H_pad, W_pad = self.input_resolution
assert H_pad % self.window_size + W_pad % self.window_size == 0,\
f'input_resolution({self.input_resolution}) is not divisible '\
f'by window_size({self.window_size}). Please check feature '\
f'map shape or set `auto_pad=True`.'
self.H_pad, self.W_pad = H_pad, W_pad
self.pad_r, self.pad_b = 0, 0
def forward(self, query, hw_shape):
B, L, C = query.shape
H, W = hw_shape
assert L == H * W, f"The query length {L} doesn't match the input "\
f'shape ({H}, {W}).'
query = query.view(B, H, W, C)
window_size = self.window_size
shift_size = self.shift_size
if min(H, W) == window_size:
# If not pad small feature map, avoid shifting when the window size
# is equal to the size of feature map. It's to align with the
# behavior of the original implementation.
shift_size = shift_size if self.pad_small_map else 0
elif min(H, W) < window_size:
# In the original implementation, the window size will be shrunk
# to the size of feature map. The behavior is different with
# swin-transformer for downstream tasks. To support dynamic input
# shape, we don't allow this feature.
assert self.pad_small_map, \
f'The input shape ({H}, {W}) is smaller than the window ' \
f'size ({window_size}). Please set `pad_small_map=True`, or ' \
'decrease the `window_size`.'
pad_r = (window_size - W % window_size) % window_size
pad_b = (window_size - H % window_size) % window_size
query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
H_pad, W_pad = query.shape[1], query.shape[2]
# cyclic shift
if shift_size > 0:
query = torch.roll(
query, shifts=(-shift_size, -shift_size), dims=(1, 2))
attn_mask = self.get_attn_mask((H_pad, W_pad),
window_size=window_size,
shift_size=shift_size,
device=query.device)
# nW*B, window_size, window_size, C
query_windows = self.window_partition(query, window_size)
# nW*B, window_size*window_size, C
query_windows = query_windows.view(-1, window_size**2, C)
# W-MSA/SW-MSA (nW*B, window_size*window_size, C)
attn_windows = self.w_msa(query_windows, mask=attn_mask)
# merge windows
attn_windows = attn_windows.view(-1, window_size, window_size, C)
# B H' W' C
shifted_x = self.window_reverse(attn_windows, H_pad, W_pad,
window_size)
# reverse cyclic shift
if self.shift_size > 0:
# calculate attention mask for SW-MSA
img_mask = torch.zeros((1, self.H_pad, self.W_pad, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
x = torch.roll(
shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
else:
x = shifted_x
if H != H_pad or W != W_pad:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = self.drop(x)
return x
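The `pad_small_map` option decides what happens when the feature map is smaller than the window: if enabled, the map is padded up to `window_size` and cropped back afterwards; otherwise the input is rejected. A short sketch of both branches, assuming `ShiftWindowMSA` is importable from `mmcls.models.utils` as in the updated `__init__`:

import torch
from mmcls.models.utils import ShiftWindowMSA

attn = ShiftWindowMSA(
    embed_dims=96, num_heads=3, window_size=7, pad_small_map=True)

# A 6x6 feature map is smaller than the 7x7 window; with
# `pad_small_map=True` it is padded to one full window and cropped back.
query = torch.randn(1, 6 * 6, 96)
out = attn(query, hw_shape=(6, 6))
assert out.shape == (1, 36, 96)

# Without padding, the same input raises the AssertionError shown above.
strict_attn = ShiftWindowMSA(embed_dims=96, num_heads=3, window_size=7)
try:
    strict_attn(query, hw_shape=(6, 6))
except AssertionError as e:
    print(e)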
@staticmethod
def window_reverse(windows, H, W, window_size):
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size,
window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
@staticmethod
def window_partition(x, window_size):
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size,
window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
windows = windows.view(-1, window_size, window_size, C)
return windows
@staticmethod
def get_attn_mask(hw_shape, window_size, shift_size, device=None):
if shift_size > 0:
img_mask = torch.zeros(1, *hw_shape, 1, device=device)
h_slices = (slice(0, -window_size), slice(-window_size,
-shift_size),
slice(-shift_size, None))
w_slices = (slice(0, -window_size), slice(-window_size,
-shift_size),
slice(-shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
@ -212,83 +295,15 @@ class ShiftWindowMSA(BaseModule):
cnt += 1
# nW, window_size, window_size, 1
mask_windows = self.window_partition(img_mask)
mask_windows = mask_windows.view(
-1, self.window_size * self.window_size)
mask_windows = ShiftWindowMSA.window_partition(
img_mask, window_size)
mask_windows = mask_windows.view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0,
float(-100.0)).masked_fill(
attn_mask == 0, float(0.0))
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0)
attn_mask = attn_mask.masked_fill(attn_mask == 0, 0.0)
else:
attn_mask = None
self.register_buffer('attn_mask', attn_mask)
def forward(self, query):
H, W = self.input_resolution
B, L, C = query.shape
assert L == H * W, 'input feature has wrong size'
query = query.view(B, H, W, C)
if self.pad_r or self.pad_b:
query = F.pad(query, (0, 0, 0, self.pad_r, 0, self.pad_b))
# cyclic shift
if self.shift_size > 0:
shifted_query = torch.roll(
query,
shifts=(-self.shift_size, -self.shift_size),
dims=(1, 2))
else:
shifted_query = query
# nW*B, window_size, window_size, C
query_windows = self.window_partition(shifted_query)
# nW*B, window_size*window_size, C
query_windows = query_windows.view(-1, self.window_size**2, C)
# W-MSA/SW-MSA (nW*B, window_size*window_size, C)
attn_windows = self.w_msa(query_windows, mask=self.attn_mask)
# merge windows
attn_windows = attn_windows.view(-1, self.window_size,
self.window_size, C)
# B H' W' C
shifted_x = self.window_reverse(attn_windows, self.H_pad, self.W_pad)
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(
shifted_x,
shifts=(self.shift_size, self.shift_size),
dims=(1, 2))
else:
x = shifted_x
if self.pad_r or self.pad_b:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = self.drop(x)
return x
def window_reverse(self, windows, H, W):
window_size = self.window_size
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size,
window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
def window_partition(self, x):
B, H, W, C = x.shape
window_size = self.window_size
x = x.view(B, H // window_size, window_size, W // window_size,
window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
windows = windows.view(-1, window_size, window_size, C)
return windows
return attn_mask
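`window_partition`, `window_reverse` and `get_attn_mask` are now static, so they can be exercised without building a whole attention module. A round-trip sketch with hypothetical shapes:

import torch
from mmcls.models.utils import ShiftWindowMSA

x = torch.randn(2, 14, 14, 96)  # (B, H, W, C)
windows = ShiftWindowMSA.window_partition(x, window_size=7)
assert windows.shape == (8, 7, 7, 96)  # nW*B, with 4 windows per image

# window_reverse is the exact inverse of window_partition.
restored = ShiftWindowMSA.window_reverse(windows, H=14, W=14, window_size=7)
assert torch.equal(restored, x)

# The shifted-window attention mask only depends on the (padded) map shape.
mask = ShiftWindowMSA.get_attn_mask((14, 14), window_size=7, shift_size=3)
assert mask.shape == (4, 49, 49)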
class MultiheadAttention(BaseModule):


@ -1,12 +1,59 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner.base_module import BaseModule
from .helpers import to_2tuple
def resize_pos_embed(pos_embed,
src_shape,
dst_shape,
mode='bicubic',
num_extra_tokens=1):
"""Resize pos_embed weights.
Args:
pos_embed (torch.Tensor): Position embedding weights with shape
[1, L, C].
src_shape (tuple): The resolution of downsampled origin training
image, in format (H, W).
dst_shape (tuple): The resolution of downsampled new training
image, in format (H, W).
mode (str): Algorithm used for upsampling. Choose one from 'nearest',
'linear', 'bilinear', 'bicubic' and 'trilinear'.
Defaults to 'bicubic'.
num_extra_tokens (int): The number of extra tokens, such as cls_token.
Defaults to 1.
Returns:
torch.Tensor: The resized pos_embed of shape [1, L_new, C]
"""
if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
return pos_embed
assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
_, L, C = pos_embed.shape
src_h, src_w = src_shape
assert L == src_h * src_w + num_extra_tokens, \
f"The length of `pos_embed` ({L}) doesn't match the expected " \
f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the ' \
'`img_size` argument.'
extra_tokens = pos_embed[:, :num_extra_tokens]
src_weight = pos_embed[:, num_extra_tokens:]
src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)
dst_weight = F.interpolate(
src_weight, size=dst_shape, align_corners=False, mode=mode)
dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
return torch.cat((extra_tokens, dst_weight), dim=1)
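`resize_pos_embed` keeps the extra tokens untouched and only interpolates the patch part of the embedding, returning the input unchanged when the shapes already match. A quick numeric sketch of the shape contract, assuming the function is imported from `mmcls.models.utils` as registered in the updated `__init__`:

import torch
from mmcls.models.utils import resize_pos_embed

# A ViT-base style table: 1 cls token plus a 14x14 patch grid, 768 channels.
pos_embed = torch.randn(1, 1 + 14 * 14, 768)

resized = resize_pos_embed(
    pos_embed, src_shape=(14, 14), dst_shape=(24, 24),
    mode='bicubic', num_extra_tokens=1)
assert resized.shape == (1, 1 + 24 * 24, 768)

# The cls token embedding is carried over unchanged.
assert torch.equal(resized[:, :1], pos_embed[:, :1])

# Identical source and target shapes short-circuit to the input tensor.
assert resize_pos_embed(pos_embed, (14, 14), (14, 14)) is pos_embed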
class PatchEmbed(BaseModule):
"""Image to Patch Embedding.
@ -32,6 +79,9 @@ class PatchEmbed(BaseModule):
conv_cfg=None,
init_cfg=None):
super(PatchEmbed, self).__init__(init_cfg)
warnings.warn('The `PatchEmbed` in mmcls will be deprecated. '
'Please use `mmcv.cnn.bricks.transformer.PatchEmbed`. '
"It's more general and supports dynamic input shape")
if isinstance(img_size, int):
img_size = to_2tuple(img_size)
@ -203,6 +253,10 @@ class PatchMerging(BaseModule):
norm_cfg=dict(type='LN'),
init_cfg=None):
super().__init__(init_cfg)
warnings.warn('The `PatchMerging` in mmcls will be deprecated. '
'Please use `mmcv.cnn.bricks.transformer.PatchMerging`. '
"It's more general and supports dynamic input shape")
H, W = input_resolution
self.input_resolution = input_resolution
self.in_channels = in_channels


@ -52,8 +52,7 @@ backbone_configs = dict(
arch='small',
drop_path_rate=0.2,
img_size=800,
out_indices=(2, 3),
auto_pad=True),
out_indices=(2, 3)),
out_channels=[384, 768]),
timm_efficientnet=dict(
backbone=dict(


@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.


@ -1,43 +1,131 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import os
import tempfile
from copy import deepcopy
from unittest import TestCase
import torch
from torch.nn.modules.batchnorm import _BatchNorm
from mmcv.runner import load_checkpoint, save_checkpoint
from mmcls.models.backbones import DistilledVisionTransformer
from .utils import timm_resize_pos_embed
def check_norm_state(modules, train_state):
"""Check if norm layer is in correct train state."""
for mod in modules:
if isinstance(mod, _BatchNorm):
if mod.training != train_state:
return False
return True
class TestDeiT(TestCase):
def setUp(self):
self.cfg = dict(
arch='deit-base', img_size=224, patch_size=16, drop_rate=0.1)
def test_deit_backbone():
cfg_ori = dict(arch='deit-b', img_size=224, patch_size=16)
def test_init_weights(self):
# test weight init cfg
cfg = deepcopy(self.cfg)
cfg['init_cfg'] = [
dict(
type='Kaiming',
layer='Conv2d',
mode='fan_in',
nonlinearity='linear')
]
model = DistilledVisionTransformer(**cfg)
ori_weight = model.patch_embed.projection.weight.clone().detach()
# The pos_embed is all zero before initialize
self.assertTrue(torch.allclose(model.dist_token, torch.tensor(0.)))
# Test structure
model = DistilledVisionTransformer(**cfg_ori)
model.init_weights()
model.train()
model.init_weights()
initialized_weight = model.patch_embed.projection.weight
self.assertFalse(torch.allclose(ori_weight, initialized_weight))
self.assertFalse(torch.allclose(model.dist_token, torch.tensor(0.)))
assert check_norm_state(model.modules(), True)
assert model.dist_token.shape == (1, 1, 768)
assert model.pos_embed.shape == (1, model.patch_embed.num_patches + 2, 768)
# test load checkpoint
pretrain_pos_embed = model.pos_embed.clone().detach()
tmpdir = tempfile.gettempdir()
checkpoint = os.path.join(tmpdir, 'test.pth')
save_checkpoint(model, checkpoint)
cfg = deepcopy(self.cfg)
model = DistilledVisionTransformer(**cfg)
load_checkpoint(model, checkpoint, strict=True)
self.assertTrue(torch.allclose(model.pos_embed, pretrain_pos_embed))
# Test forward
imgs = torch.rand(1, 3, 224, 224)
outs = model(imgs)
patch_token, cls_token, dist_token = outs[0]
assert patch_token.shape == (1, 768, 14, 14)
assert cls_token.shape == (1, 768)
assert dist_token.shape == (1, 768)
# test load checkpoint with different img_size
cfg = deepcopy(self.cfg)
cfg['img_size'] = 384
model = DistilledVisionTransformer(**cfg)
load_checkpoint(model, checkpoint, strict=True)
resized_pos_embed = timm_resize_pos_embed(
pretrain_pos_embed, model.pos_embed, num_tokens=2)
self.assertTrue(torch.allclose(model.pos_embed, resized_pos_embed))
# Test multiple out_indices
model = DistilledVisionTransformer(
**cfg_ori, out_indices=(0, 1, 2, 3), output_cls_token=False)
outs = model(imgs)
for out in outs:
assert out.shape == (1, 768, 14, 14)
os.remove(checkpoint)
def test_forward(self):
imgs = torch.randn(3, 3, 224, 224)
# test with_cls_token=False
cfg = deepcopy(self.cfg)
cfg['with_cls_token'] = False
cfg['output_cls_token'] = True
with self.assertRaisesRegex(AssertionError, 'but got False'):
DistilledVisionTransformer(**cfg)
cfg = deepcopy(self.cfg)
cfg['with_cls_token'] = False
cfg['output_cls_token'] = False
model = DistilledVisionTransformer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token = outs[-1]
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
# test with output_cls_token
cfg = deepcopy(self.cfg)
model = DistilledVisionTransformer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token, cls_token, dist_token = outs[-1]
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
self.assertEqual(cls_token.shape, (3, 768))
self.assertEqual(dist_token.shape, (3, 768))
# test without output_cls_token
cfg = deepcopy(self.cfg)
cfg['output_cls_token'] = False
model = DistilledVisionTransformer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token = outs[-1]
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
# Test forward with multi out indices
cfg = deepcopy(self.cfg)
cfg['out_indices'] = [-3, -2, -1]
model = DistilledVisionTransformer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 3)
for out in outs:
patch_token, cls_token, dist_token = out
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
self.assertEqual(cls_token.shape, (3, 768))
self.assertEqual(dist_token.shape, (3, 768))
# Test forward with dynamic input size
imgs1 = torch.randn(3, 3, 224, 224)
imgs2 = torch.randn(3, 3, 256, 256)
imgs3 = torch.randn(3, 3, 256, 309)
cfg = deepcopy(self.cfg)
model = DistilledVisionTransformer(**cfg)
for imgs in [imgs1, imgs2, imgs3]:
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token, cls_token, dist_token = outs[-1]
expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
math.ceil(imgs.shape[3] / 16))
self.assertEqual(patch_token.shape, (3, 768, *expect_feat_shape))
self.assertEqual(cls_token.shape, (3, 768))
self.assertEqual(dist_token.shape, (3, 768))
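The expected feature-map sizes in the dynamic-input loop above come from ceil-dividing the input height and width by the patch size, on the assumption that the patch embedding pads the input to a full patch grid. A tiny sketch with a hypothetical `expected_grid` helper (not part of mmcls):
import math

def expected_grid(h, w, patch=16):
    # Illustrative helper only: ceil division, assuming padding covers the
    # last partial patch.
    return math.ceil(h / patch), math.ceil(w / patch)

assert expected_grid(224, 224) == (14, 14)
assert expected_grid(256, 309) == (16, 20)  # 309 / 16 = 19.3 -> rounds up to 20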

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from unittest import TestCase
import pytest
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
@ -25,51 +25,95 @@ def check_norm_state(modules, train_state):
return True
def test_mlp_mixer_backbone():
cfg_ori = dict(
arch='b',
img_size=224,
patch_size=16,
drop_rate=0.1,
init_cfg=[
class TestMLPMixer(TestCase):
def setUp(self):
self.cfg = dict(
arch='b',
img_size=224,
patch_size=16,
drop_rate=0.1,
init_cfg=[
dict(
type='Kaiming',
layer='Conv2d',
mode='fan_in',
nonlinearity='linear')
])
def test_arch(self):
# Test invalid default arch
with self.assertRaisesRegex(AssertionError, 'not in default archs'):
cfg = deepcopy(self.cfg)
cfg['arch'] = 'unknown'
MlpMixer(**cfg)
# Test invalid custom arch
with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
cfg = deepcopy(self.cfg)
cfg['arch'] = {
'embed_dims': 24,
'num_layers': 16,
'tokens_mlp_dims': 4096
}
MlpMixer(**cfg)
# Test custom arch
cfg = deepcopy(self.cfg)
cfg['arch'] = {
'embed_dims': 128,
'num_layers': 6,
'tokens_mlp_dims': 256,
'channels_mlp_dims': 1024
}
model = MlpMixer(**cfg)
self.assertEqual(model.embed_dims, 128)
self.assertEqual(model.num_layers, 6)
for layer in model.layers:
self.assertEqual(layer.token_mix.feedforward_channels, 256)
self.assertEqual(layer.channel_mix.feedforward_channels, 1024)
def test_init_weights(self):
# test weight init cfg
cfg = deepcopy(self.cfg)
cfg['init_cfg'] = [
dict(
type='Kaiming',
layer='Conv2d',
mode='fan_in',
nonlinearity='linear')
])
]
model = MlpMixer(**cfg)
ori_weight = model.patch_embed.projection.weight.clone().detach()
model.init_weights()
initialized_weight = model.patch_embed.projection.weight
self.assertFalse(torch.allclose(ori_weight, initialized_weight))
with pytest.raises(AssertionError):
# test invalid arch
cfg = deepcopy(cfg_ori)
cfg['arch'] = 'unknown'
MlpMixer(**cfg)
def test_forward(self):
imgs = torch.randn(3, 3, 224, 224)
with pytest.raises(AssertionError):
# test arch without essential keys
cfg = deepcopy(cfg_ori)
cfg['arch'] = {
'num_layers': 24,
'tokens_mlp_dims': 384,
'channels_mlp_dims': 3072,
}
MlpMixer(**cfg)
# test forward with single out indices
cfg = deepcopy(self.cfg)
model = MlpMixer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
feat = outs[-1]
self.assertEqual(feat.shape, (3, 768, 196))
# Test MlpMixer base model with input size of 224
# and patch size of 16
model = MlpMixer(**cfg_ori)
model.init_weights()
model.train()
# test forward with multi out indices
cfg = deepcopy(self.cfg)
cfg['out_indices'] = [-3, -2, -1]
model = MlpMixer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 3)
for feat in outs:
self.assertEqual(feat.shape, (3, 768, 196))
assert check_norm_state(model.modules(), True)
imgs = torch.randn(3, 3, 224, 224)
feat = model(imgs)[-1]
assert feat.shape == (3, 768, 196)
# Test MlpMixer with multi out indices
cfg = deepcopy(cfg_ori)
cfg['out_indices'] = [-3, -2, -1]
model = MlpMixer(**cfg)
for out in model(imgs):
assert out.shape == (3, 768, 196)
# test with invalid input shape
imgs2 = torch.randn(3, 3, 256, 256)
cfg = deepcopy(self.cfg)
model = MlpMixer(**cfg)
with self.assertRaisesRegex(AssertionError, 'dynamic input shape.'):
model(imgs2)
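The assertion above is expected because the token-mixing MLP in MLP-Mixer acts on a fixed-length token axis, so the backbone cannot accept a different input resolution. A rough standalone illustration in plain PyTorch (sketch only, not the mmcls module itself):
import torch
import torch.nn as nn

# Sketch: a token-mixing layer over 196 tokens (224 / 16 = 14, 14 * 14 = 196).
token_mix = nn.Linear(196, 384)

tokens_224 = torch.randn(1, 768, 196)  # (batch, channels, tokens)
token_mix(tokens_224)                  # works: 196 tokens as expected

tokens_256 = torch.randn(1, 768, 256)  # 256 / 16 = 16, 16 * 16 = 256 tokens
try:
    token_mix(tokens_256)              # fails: the weight expects 196 tokens
except RuntimeError as err:
    print(err)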

View File

@ -1,16 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import os
import tempfile
from math import ceil
from copy import deepcopy
from itertools import chain
from unittest import TestCase
import numpy as np
import pytest
import torch
from mmcv.runner import load_checkpoint, save_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmcls.models.backbones import SwinTransformer
from mmcls.models.backbones.swin_transformer import SwinBlock
from .utils import timm_resize_pos_embed
def check_norm_state(modules, train_state):
@ -22,215 +24,232 @@ def check_norm_state(modules, train_state):
return True
def test_assertion():
"""Test Swin Transformer backbone."""
with pytest.raises(AssertionError):
# Swin Transformer arch string should be in
SwinTransformer(arch='unknown')
class TestSwinTransformer(TestCase):
with pytest.raises(AssertionError):
# Swin Transformer arch dict should include 'embed_dims',
# 'depths' and 'num_head' keys.
SwinTransformer(arch=dict(embed_dims=96, depths=[2, 2, 18, 2]))
def setUp(self):
self.cfg = dict(
arch='b', img_size=224, patch_size=4, drop_path_rate=0.1)
def test_arch(self):
# Test invalid default arch
with self.assertRaisesRegex(AssertionError, 'not in default archs'):
cfg = deepcopy(self.cfg)
cfg['arch'] = 'unknown'
SwinTransformer(**cfg)
def test_forward():
# Test tiny arch forward
model = SwinTransformer(arch='Tiny')
model.init_weights()
model.train()
# Test invalid custom arch
with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
cfg = deepcopy(self.cfg)
cfg['arch'] = {
'embed_dims': 96,
'num_heads': [3, 6, 12, 16],
}
SwinTransformer(**cfg)
imgs = torch.randn(1, 3, 224, 224)
output = model(imgs)
assert len(output) == 1
assert output[0].shape == (1, 768, 7, 7)
# Test custom arch
cfg = deepcopy(self.cfg)
depths = [2, 2, 4, 2]
num_heads = [6, 12, 6, 12]
cfg['arch'] = {
'embed_dims': 256,
'depths': depths,
'num_heads': num_heads
}
model = SwinTransformer(**cfg)
for i, stage in enumerate(model.stages):
self.assertEqual(stage.embed_dims, 256 * (2**i))
self.assertEqual(len(stage.blocks), depths[i])
self.assertEqual(stage.blocks[0].attn.w_msa.num_heads,
num_heads[i])
# Test small arch forward
model = SwinTransformer(arch='small')
model.init_weights()
model.train()
def test_init_weights(self):
# test weight init cfg
cfg = deepcopy(self.cfg)
cfg['use_abs_pos_embed'] = True
cfg['init_cfg'] = [
dict(
type='Kaiming',
layer='Conv2d',
mode='fan_in',
nonlinearity='linear')
]
model = SwinTransformer(**cfg)
ori_weight = model.patch_embed.projection.weight.clone().detach()
# The absolute_pos_embed is all zero before initialization
self.assertTrue(
torch.allclose(model.absolute_pos_embed, torch.tensor(0.)))
imgs = torch.randn(1, 3, 224, 224)
output = model(imgs)
assert len(output) == 1
assert output[0].shape == (1, 768, 7, 7)
model.init_weights()
initialized_weight = model.patch_embed.projection.weight
self.assertFalse(torch.allclose(ori_weight, initialized_weight))
self.assertFalse(
torch.allclose(model.absolute_pos_embed, torch.tensor(0.)))
# Test base arch forward
model = SwinTransformer(arch='B')
model.init_weights()
model.train()
pretrain_pos_embed = model.absolute_pos_embed.clone().detach()
imgs = torch.randn(1, 3, 224, 224)
output = model(imgs)
assert len(output) == 1
assert output[0].shape == (1, 1024, 7, 7)
tmpdir = tempfile.gettempdir()
# Save v3 checkpoints
checkpoint_v2 = os.path.join(tmpdir, 'v3.pth')
save_checkpoint(model, checkpoint_v2)
# Save v1 checkpoints
setattr(model, 'norm', model.norm3)
setattr(model.stages[0].blocks[1].attn, 'attn_mask',
torch.zeros(64, 49, 49))
model._version = 1
del model.norm3
checkpoint_v1 = os.path.join(tmpdir, 'v1.pth')
save_checkpoint(model, checkpoint_v1)
# Test large arch forward
model = SwinTransformer(arch='l')
model.init_weights()
model.train()
# test load v1 checkpoint
cfg = deepcopy(self.cfg)
cfg['use_abs_pos_embed'] = True
model = SwinTransformer(**cfg)
load_checkpoint(model, checkpoint_v1, strict=True)
imgs = torch.randn(1, 3, 224, 224)
output = model(imgs)
assert len(output) == 1
assert output[0].shape == (1, 1536, 7, 7)
# test load v3 checkpoint
cfg = deepcopy(self.cfg)
cfg['use_abs_pos_embed'] = True
model = SwinTransformer(**cfg)
load_checkpoint(model, checkpoint_v2, strict=True)
# Test base arch with window_size=12, image_size=384
model = SwinTransformer(
arch='base',
img_size=384,
stage_cfgs=dict(block_cfgs=dict(window_size=12)))
model.init_weights()
model.train()
# test load v3 checkpoint with different img_size
cfg = deepcopy(self.cfg)
cfg['img_size'] = 384
cfg['use_abs_pos_embed'] = True
model = SwinTransformer(**cfg)
load_checkpoint(model, checkpoint_v2, strict=True)
resized_pos_embed = timm_resize_pos_embed(
pretrain_pos_embed, model.absolute_pos_embed, num_tokens=0)
self.assertTrue(
torch.allclose(model.absolute_pos_embed, resized_pos_embed))
imgs = torch.randn(1, 3, 384, 384)
output = model(imgs)
assert len(output) == 1
assert output[0].shape == (1, 1024, 12, 12)
os.remove(checkpoint_v1)
os.remove(checkpoint_v2)
# Test multiple output indices
imgs = torch.randn(1, 3, 224, 224)
model = SwinTransformer(arch='base', out_indices=(0, 1, 2, 3))
outs = model(imgs)
assert outs[0].shape == (1, 256, 28, 28)
assert outs[1].shape == (1, 512, 14, 14)
assert outs[2].shape == (1, 1024, 7, 7)
assert outs[3].shape == (1, 1024, 7, 7)
def test_forward(self):
imgs = torch.randn(3, 3, 224, 224)
# Test base arch with with checkpoint forward
model = SwinTransformer(arch='B', with_cp=True)
for m in model.modules():
if isinstance(m, SwinBlock):
assert m.with_cp
model.init_weights()
model.train()
cfg = deepcopy(self.cfg)
model = SwinTransformer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
feat = outs[-1]
self.assertEqual(feat.shape, (3, 1024, 7, 7))
imgs = torch.randn(1, 3, 224, 224)
output = model(imgs)
assert len(output) == 1
assert output[0].shape == (1, 1024, 7, 7)
# test with window_size=12
cfg = deepcopy(self.cfg)
cfg['window_size'] = 12
model = SwinTransformer(**cfg)
outs = model(torch.randn(3, 3, 384, 384))
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
feat = outs[-1]
self.assertEqual(feat.shape, (3, 1024, 12, 12))
with self.assertRaisesRegex(AssertionError, r'the window size \(12\)'):
model(torch.randn(3, 3, 224, 224))
# test with pad_small_map=True
cfg = deepcopy(self.cfg)
cfg['window_size'] = 12
cfg['pad_small_map'] = True
model = SwinTransformer(**cfg)
outs = model(torch.randn(3, 3, 224, 224))
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
feat = outs[-1]
self.assertEqual(feat.shape, (3, 1024, 7, 7))
def test_structure():
# Test small with use_abs_pos_embed = True
model = SwinTransformer(arch='small', use_abs_pos_embed=True)
assert model.absolute_pos_embed.shape == (1, 3136, 96)
# test multiple output indices
cfg = deepcopy(self.cfg)
cfg['out_indices'] = (0, 1, 2, 3)
model = SwinTransformer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 4)
for stride, out in zip([2, 4, 8, 8], outs):
self.assertEqual(out.shape,
(3, 128 * stride, 56 // stride, 56 // stride))
# Test small with use_abs_pos_embed = False
model = SwinTransformer(arch='small', use_abs_pos_embed=False)
assert not hasattr(model, 'absolute_pos_embed')
# test with checkpoint forward
cfg = deepcopy(self.cfg)
cfg['with_cp'] = True
model = SwinTransformer(**cfg)
for m in model.modules():
if isinstance(m, SwinBlock):
self.assertTrue(m.with_cp)
model.init_weights()
model.train()
# Test small with auto_pad = True
model = SwinTransformer(
arch='small',
auto_pad=True,
stage_cfgs=dict(
block_cfgs={'window_size': 7},
downsample_cfg={
'kernel_size': (3, 2),
}))
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
feat = outs[-1]
self.assertEqual(feat.shape, (3, 1024, 7, 7))
# stage 1
input_h = int(224 / 4 / 3)
expect_h = ceil(input_h / 7) * 7
input_w = int(224 / 4 / 2)
expect_w = ceil(input_w / 7) * 7
assert model.stages[1].blocks[0].attn.pad_b == expect_h - input_h
assert model.stages[1].blocks[0].attn.pad_r == expect_w - input_w
# test with dynamic input shape
imgs1 = torch.randn(3, 3, 224, 224)
imgs2 = torch.randn(3, 3, 256, 256)
imgs3 = torch.randn(3, 3, 256, 309)
cfg = deepcopy(self.cfg)
model = SwinTransformer(**cfg)
for imgs in [imgs1, imgs2, imgs3]:
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
feat = outs[-1]
expect_feat_shape = (math.ceil(imgs.shape[2] / 32),
math.ceil(imgs.shape[3] / 32))
self.assertEqual(feat.shape, (3, 1024, *expect_feat_shape))
# stage 2
input_h = int(224 / 4 / 3 / 3)
# input_h is smaller than window_size, shrink the window_size to input_h.
expect_h = input_h
input_w = int(224 / 4 / 2 / 2)
expect_w = ceil(input_w / input_h) * input_h
assert model.stages[2].blocks[0].attn.pad_b == expect_h - input_h
assert model.stages[2].blocks[0].attn.pad_r == expect_w - input_w
def test_structure(self):
# test drop_path_rate decay
cfg = deepcopy(self.cfg)
cfg['drop_path_rate'] = 0.2
model = SwinTransformer(**cfg)
depths = model.arch_settings['depths']
blocks = chain(*[stage.blocks for stage in model.stages])
for i, block in enumerate(blocks):
expect_prob = 0.2 / (sum(depths) - 1) * i
self.assertAlmostEqual(block.ffn.dropout_layer.drop_prob,
expect_prob)
self.assertAlmostEqual(block.attn.drop.drop_prob, expect_prob)
# stage 3
input_h = int(224 / 4 / 3 / 3 / 3)
expect_h = input_h
input_w = int(224 / 4 / 2 / 2 / 2)
expect_w = ceil(input_w / input_h) * input_h
assert model.stages[3].blocks[0].attn.pad_b == expect_h - input_h
assert model.stages[3].blocks[0].attn.pad_r == expect_w - input_w
# test Swin-Transformer with norm_eval=True
cfg = deepcopy(self.cfg)
cfg['norm_eval'] = True
cfg['norm_cfg'] = dict(type='BN')
cfg['stage_cfgs'] = dict(block_cfgs=dict(norm_cfg=dict(type='BN')))
model = SwinTransformer(**cfg)
model.init_weights()
model.train()
self.assertTrue(check_norm_state(model.modules(), False))
# Test small with auto_pad = False
with pytest.raises(AssertionError):
model = SwinTransformer(
arch='small',
auto_pad=False,
stage_cfgs=dict(
block_cfgs={'window_size': 7},
downsample_cfg={
'kernel_size': (3, 2),
}))
# test Swin-Transformer with first stage frozen.
cfg = deepcopy(self.cfg)
frozen_stages = 0
cfg['frozen_stages'] = frozen_stages
cfg['out_indices'] = (0, 1, 2, 3)
model = SwinTransformer(**cfg)
model.init_weights()
model.train()
# Test drop_path_rate decay
model = SwinTransformer(
arch='small',
drop_path_rate=0.2,
)
depths = model.arch_settings['depths']
pos = 0
for i, depth in enumerate(depths):
for j in range(depth):
block = model.stages[i].blocks[j]
expect_prob = 0.2 / (sum(depths) - 1) * pos
assert np.isclose(block.ffn.dropout_layer.drop_prob, expect_prob)
assert np.isclose(block.attn.drop.drop_prob, expect_prob)
pos += 1
# the patch_embed and first stage should not require grad.
self.assertFalse(model.patch_embed.training)
for param in model.patch_embed.parameters():
self.assertFalse(param.requires_grad)
for i in range(frozen_stages + 1):
stage = model.stages[i]
for param in stage.parameters():
self.assertFalse(param.requires_grad)
for param in model.norm0.parameters():
self.assertFalse(param.requires_grad)
# Test Swin-Transformer with norm_eval=True
model = SwinTransformer(
arch='small',
norm_eval=True,
norm_cfg=dict(type='BN'),
stage_cfgs=dict(block_cfgs=dict(norm_cfg=dict(type='BN'))),
)
model.init_weights()
model.train()
assert check_norm_state(model.modules(), False)
# Test Swin-Transformer with first stage frozen.
frozen_stages = 0
model = SwinTransformer(
arch='small', frozen_stages=frozen_stages, out_indices=(0, 1, 2, 3))
model.init_weights()
model.train()
assert model.patch_embed.training is False
for param in model.patch_embed.parameters():
assert param.requires_grad is False
for i in range(frozen_stages + 1):
stage = model.stages[i]
for param in stage.parameters():
assert param.requires_grad is False
for param in model.norm0.parameters():
assert param.requires_grad is False
# the second stage should require grad.
stage = model.stages[1]
for param in stage.parameters():
assert param.requires_grad is True
for param in model.norm1.parameters():
assert param.requires_grad is True
def test_load_checkpoint():
model = SwinTransformer(arch='tiny')
ckpt_path = os.path.join(tempfile.gettempdir(), 'ckpt.pth')
assert model._version == 2
# test load v2 checkpoint
save_checkpoint(model, ckpt_path)
load_checkpoint(model, ckpt_path, strict=True)
# test load v1 checkpoint
setattr(model, 'norm', model.norm3)
model._version = 1
del model.norm3
save_checkpoint(model, ckpt_path)
model = SwinTransformer(arch='tiny')
load_checkpoint(model, ckpt_path, strict=True)
# the second stage should require grad.
for i in range(frozen_stages + 1, 4):
stage = model.stages[i]
for param in stage.parameters():
self.assertTrue(param.requires_grad)
norm = getattr(model, f'norm{i}')
for param in norm.parameters():
self.assertTrue(param.requires_grad)
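For reference, the assertions above rely on the usual frozen-stage convention: the patch embedding and every stage up to `frozen_stages` are put in eval mode with gradients disabled, while later stages keep training normally. A minimal sketch of that assumed behaviour (not the mmcls implementation):
import torch.nn as nn

def freeze_stages(model, frozen_stages=0):
    # Sketch of the assumed freezing pattern; `model` is expected to expose
    # `patch_embed`, `stages` and per-stage `norm{i}` modules.
    model.patch_embed.eval()
    for param in model.patch_embed.parameters():
        param.requires_grad = False
    for i in range(frozen_stages + 1):
        stage = model.stages[i]
        stage.eval()
        for param in stage.parameters():
            param.requires_grad = False
        norm = getattr(model, f'norm{i}', None)
        if isinstance(norm, nn.Module):
            for param in norm.parameters():
                param.requires_grad = False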

View File

@ -1,84 +1,157 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import os
import tempfile
from copy import deepcopy
from unittest import TestCase
import pytest
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmcv.runner import load_checkpoint, save_checkpoint
from mmcls.models.backbones import T2T_ViT
from .utils import timm_resize_pos_embed
def is_norm(modules):
"""Check if is one of the norms."""
if isinstance(modules, (GroupNorm, _BatchNorm)):
return True
return False
class TestT2TViT(TestCase):
def setUp(self):
self.cfg = dict(
img_size=224,
in_channels=3,
embed_dims=384,
t2t_cfg=dict(
token_dims=64,
use_performer=False,
),
num_layers=14,
drop_path_rate=0.1)
def check_norm_state(modules, train_state):
"""Check if norm layer is in correct train state."""
for mod in modules:
if isinstance(mod, _BatchNorm):
if mod.training != train_state:
return False
return True
def test_vit_backbone():
cfg_ori = dict(
img_size=224,
in_channels=3,
embed_dims=384,
t2t_cfg=dict(
token_dims=64,
use_performer=False,
),
num_layers=14,
layer_cfgs=dict(
num_heads=6,
feedforward_channels=3 * 384, # mlp_ratio = 3
),
drop_path_rate=0.1,
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=.02),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
])
with pytest.raises(NotImplementedError):
# test if use performer
cfg = deepcopy(cfg_ori)
def test_structure(self):
# The performer hasn't been implemented
cfg = deepcopy(self.cfg)
cfg['t2t_cfg']['use_performer'] = True
T2T_ViT(**cfg)
with self.assertRaises(NotImplementedError):
T2T_ViT(**cfg)
# Test T2T-ViT model with input size of 224
model = T2T_ViT(**cfg_ori)
model.init_weights()
model.train()
# Test out_indices
cfg = deepcopy(self.cfg)
cfg['out_indices'] = {1: 1}
with self.assertRaisesRegex(AssertionError, "get <class 'dict'>"):
T2T_ViT(**cfg)
cfg['out_indices'] = [0, 15]
with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 15'):
T2T_ViT(**cfg)
assert check_norm_state(model.modules(), True)
# Test model structure
cfg = deepcopy(self.cfg)
model = T2T_ViT(**cfg)
self.assertEqual(len(model.encoder), 14)
dpr_inc = 0.1 / (14 - 1)
dpr = 0
for layer in model.encoder:
self.assertEqual(layer.attn.embed_dims, 384)
# The default mlp_ratio is 3
self.assertEqual(layer.ffn.feedforward_channels, 384 * 3)
self.assertAlmostEqual(layer.attn.out_drop.drop_prob, dpr)
self.assertAlmostEqual(layer.ffn.dropout_layer.drop_prob, dpr)
dpr += dpr_inc
imgs = torch.randn(3, 3, 224, 224)
patch_token, cls_token = model(imgs)[-1]
assert cls_token.shape == (3, 384)
assert patch_token.shape == (3, 384, 14, 14)
def test_init_weights(self):
# test weight init cfg
cfg = deepcopy(self.cfg)
cfg['init_cfg'] = [dict(type='TruncNormal', layer='Linear', std=.02)]
model = T2T_ViT(**cfg)
ori_weight = model.tokens_to_token.project.weight.clone().detach()
# Test custom arch T2T-ViT without output cls token
cfg = deepcopy(cfg_ori)
cfg['embed_dims'] = 256
cfg['num_layers'] = 16
cfg['layer_cfgs'] = dict(num_heads=8, feedforward_channels=1024)
cfg['output_cls_token'] = False
model.init_weights()
initialized_weight = model.tokens_to_token.project.weight
self.assertFalse(torch.allclose(ori_weight, initialized_weight))
model = T2T_ViT(**cfg)
patch_token = model(imgs)[-1]
assert patch_token.shape == (3, 256, 14, 14)
# test load checkpoint
pretrain_pos_embed = model.pos_embed.clone().detach()
tmpdir = tempfile.gettempdir()
checkpoint = os.path.join(tmpdir, 'test.pth')
save_checkpoint(model, checkpoint)
cfg = deepcopy(self.cfg)
model = T2T_ViT(**cfg)
load_checkpoint(model, checkpoint, strict=True)
self.assertTrue(torch.allclose(model.pos_embed, pretrain_pos_embed))
# Test T2T_ViT with multi out indices
cfg = deepcopy(cfg_ori)
cfg['out_indices'] = [-3, -2, -1]
model = T2T_ViT(**cfg)
for out in model(imgs):
assert out[0].shape == (3, 384, 14, 14)
assert out[1].shape == (3, 384)
# test load checkpoint with different img_size
cfg = deepcopy(self.cfg)
cfg['img_size'] = 384
model = T2T_ViT(**cfg)
load_checkpoint(model, checkpoint, strict=True)
resized_pos_embed = timm_resize_pos_embed(pretrain_pos_embed,
model.pos_embed)
self.assertTrue(torch.allclose(model.pos_embed, resized_pos_embed))
os.remove(checkpoint)
def test_forward(self):
imgs = torch.randn(3, 3, 224, 224)
# test with_cls_token=False
cfg = deepcopy(self.cfg)
cfg['with_cls_token'] = False
cfg['output_cls_token'] = True
with self.assertRaisesRegex(AssertionError, 'but got False'):
T2T_ViT(**cfg)
cfg = deepcopy(self.cfg)
cfg['with_cls_token'] = False
cfg['output_cls_token'] = False
model = T2T_ViT(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token = outs[-1]
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
# test with output_cls_token
cfg = deepcopy(self.cfg)
model = T2T_ViT(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token, cls_token = outs[-1]
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
self.assertEqual(cls_token.shape, (3, 384))
# test without output_cls_token
cfg = deepcopy(self.cfg)
cfg['output_cls_token'] = False
model = T2T_ViT(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token = outs[-1]
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
# Test forward with multi out indices
cfg = deepcopy(self.cfg)
cfg['out_indices'] = [-3, -2, -1]
model = T2T_ViT(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 3)
for out in outs:
patch_token, cls_token = out
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
self.assertEqual(cls_token.shape, (3, 384))
# Test forward with dynamic input size
imgs1 = torch.randn(3, 3, 224, 224)
imgs2 = torch.randn(3, 3, 256, 256)
imgs3 = torch.randn(3, 3, 256, 309)
cfg = deepcopy(self.cfg)
model = T2T_ViT(**cfg)
for imgs in [imgs1, imgs2, imgs3]:
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token, cls_token = outs[-1]
expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
math.ceil(imgs.shape[3] / 16))
self.assertEqual(patch_token.shape, (3, 384, *expect_feat_shape))
self.assertEqual(cls_token.shape, (3, 384))

View File

@ -3,160 +3,181 @@ import math
import os
import tempfile
from copy import deepcopy
from unittest import TestCase
import pytest
import torch
import torch.nn.functional as F
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmcv.runner import load_checkpoint, save_checkpoint
from mmcls.models.backbones import VisionTransformer
from .utils import timm_resize_pos_embed
def is_norm(modules):
"""Check if is one of the norms."""
if isinstance(modules, (GroupNorm, _BatchNorm)):
return True
return False
class TestVisionTransformer(TestCase):
def setUp(self):
self.cfg = dict(
arch='b', img_size=224, patch_size=16, drop_path_rate=0.1)
def check_norm_state(modules, train_state):
"""Check if norm layer is in correct train state."""
for mod in modules:
if isinstance(mod, _BatchNorm):
if mod.training != train_state:
return False
return True
def test_structure(self):
# Test invalid default arch
with self.assertRaisesRegex(AssertionError, 'not in default archs'):
cfg = deepcopy(self.cfg)
cfg['arch'] = 'unknown'
VisionTransformer(**cfg)
# Test invalid custom arch
with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
cfg = deepcopy(self.cfg)
cfg['arch'] = {
'num_layers': 24,
'num_heads': 16,
'feedforward_channels': 4096
}
VisionTransformer(**cfg)
def test_vit_backbone():
# Test custom arch
cfg = deepcopy(self.cfg)
cfg['arch'] = {
'embed_dims': 128,
'num_layers': 24,
'num_heads': 16,
'feedforward_channels': 1024
}
model = VisionTransformer(**cfg)
self.assertEqual(model.embed_dims, 128)
self.assertEqual(model.num_layers, 24)
for layer in model.layers:
self.assertEqual(layer.attn.num_heads, 16)
self.assertEqual(layer.ffn.feedforward_channels, 1024)
cfg_ori = dict(
arch='b',
img_size=224,
patch_size=16,
drop_rate=0.1,
init_cfg=[
# Test out_indices
cfg = deepcopy(self.cfg)
cfg['out_indices'] = {1: 1}
with self.assertRaisesRegex(AssertionError, "get <class 'dict'>"):
VisionTransformer(**cfg)
cfg['out_indices'] = [0, 13]
with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'):
VisionTransformer(**cfg)
# Test model structure
cfg = deepcopy(self.cfg)
model = VisionTransformer(**cfg)
self.assertEqual(len(model.layers), 12)
dpr_inc = 0.1 / (12 - 1)
dpr = 0
for layer in model.layers:
self.assertEqual(layer.attn.embed_dims, 768)
self.assertEqual(layer.attn.num_heads, 12)
self.assertEqual(layer.ffn.feedforward_channels, 3072)
self.assertAlmostEqual(layer.attn.out_drop.drop_prob, dpr)
self.assertAlmostEqual(layer.ffn.dropout_layer.drop_prob, dpr)
dpr += dpr_inc
def test_init_weights(self):
# test weight init cfg
cfg = deepcopy(self.cfg)
cfg['init_cfg'] = [
dict(
type='Kaiming',
layer='Conv2d',
mode='fan_in',
nonlinearity='linear')
])
]
model = VisionTransformer(**cfg)
ori_weight = model.patch_embed.projection.weight.clone().detach()
# The pos_embed is all zero before initialization
self.assertTrue(torch.allclose(model.pos_embed, torch.tensor(0.)))
with pytest.raises(AssertionError):
# test invalid arch
cfg = deepcopy(cfg_ori)
cfg['arch'] = 'unknown'
VisionTransformer(**cfg)
model.init_weights()
initialized_weight = model.patch_embed.projection.weight
self.assertFalse(torch.allclose(ori_weight, initialized_weight))
self.assertFalse(torch.allclose(model.pos_embed, torch.tensor(0.)))
with pytest.raises(AssertionError):
# test arch without essential keys
cfg = deepcopy(cfg_ori)
cfg['arch'] = {
'num_layers': 24,
'num_heads': 16,
'feedforward_channels': 4096
}
VisionTransformer(**cfg)
# test load checkpoint
pretrain_pos_embed = model.pos_embed.clone().detach()
tmpdir = tempfile.gettempdir()
checkpoint = os.path.join(tmpdir, 'test.pth')
save_checkpoint(model, checkpoint)
cfg = deepcopy(self.cfg)
model = VisionTransformer(**cfg)
load_checkpoint(model, checkpoint, strict=True)
self.assertTrue(torch.allclose(model.pos_embed, pretrain_pos_embed))
# Test ViT base model with input size of 224
# and patch size of 16
model = VisionTransformer(**cfg_ori)
model.init_weights()
model.train()
# test load checkpoint with different img_size
cfg = deepcopy(self.cfg)
cfg['img_size'] = 384
model = VisionTransformer(**cfg)
load_checkpoint(model, checkpoint, strict=True)
resized_pos_embed = timm_resize_pos_embed(pretrain_pos_embed,
model.pos_embed)
self.assertTrue(torch.allclose(model.pos_embed, resized_pos_embed))
assert check_norm_state(model.modules(), True)
os.remove(checkpoint)
imgs = torch.randn(3, 3, 224, 224)
patch_token, cls_token = model(imgs)[-1]
assert cls_token.shape == (3, 768)
assert patch_token.shape == (3, 768, 14, 14)
def test_forward(self):
imgs = torch.randn(3, 3, 224, 224)
# Test custom arch ViT without output cls token
cfg = deepcopy(cfg_ori)
cfg['arch'] = {
'embed_dims': 128,
'num_layers': 24,
'num_heads': 16,
'feedforward_channels': 1024
}
cfg['output_cls_token'] = False
model = VisionTransformer(**cfg)
patch_token = model(imgs)[-1]
assert patch_token.shape == (3, 128, 14, 14)
# test with_cls_token=False
cfg = deepcopy(self.cfg)
cfg['with_cls_token'] = False
cfg['output_cls_token'] = True
with self.assertRaisesRegex(AssertionError, 'but got False'):
VisionTransformer(**cfg)
# Test ViT with multi out indices
cfg = deepcopy(cfg_ori)
cfg['out_indices'] = [-3, -2, -1]
model = VisionTransformer(**cfg)
for out in model(imgs):
assert out[0].shape == (3, 768, 14, 14)
assert out[1].shape == (3, 768)
cfg = deepcopy(self.cfg)
cfg['with_cls_token'] = False
cfg['output_cls_token'] = False
model = VisionTransformer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token = outs[-1]
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
# test with output_cls_token
cfg = deepcopy(self.cfg)
model = VisionTransformer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token, cls_token = outs[-1]
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
self.assertEqual(cls_token.shape, (3, 768))
def timm_resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
# Timm version pos embed resize function.
# Refers to https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # noqa:E501
ntok_new = posemb_new.shape[1]
if num_tokens:
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0,
num_tokens:]
ntok_new -= num_tokens
else:
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
gs_old = int(math.sqrt(len(posemb_grid)))
if not len(gs_new): # backwards compatibility
gs_new = [int(math.sqrt(ntok_new))] * 2
assert len(gs_new) >= 2
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old,
-1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(
posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
posemb_grid = posemb_grid.permute(0, 2, 3,
1).reshape(1, gs_new[0] * gs_new[1], -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
return posemb
# test without output_cls_token
cfg = deepcopy(self.cfg)
cfg['output_cls_token'] = False
model = VisionTransformer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token = outs[-1]
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
# Test forward with multi out indices
cfg = deepcopy(self.cfg)
cfg['out_indices'] = [-3, -2, -1]
model = VisionTransformer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 3)
for out in outs:
patch_token, cls_token = out
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
self.assertEqual(cls_token.shape, (3, 768))
def test_vit_weight_init():
# test weight init cfg
pretrain_cfg = dict(
arch='b',
img_size=224,
patch_size=16,
init_cfg=[dict(type='Constant', val=1., layer='Conv2d')])
pretrain_model = VisionTransformer(**pretrain_cfg)
pretrain_model.init_weights()
assert torch.allclose(pretrain_model.patch_embed.projection.weight,
torch.tensor(1.))
assert pretrain_model.pos_embed.abs().sum() > 0
pos_embed_weight = pretrain_model.pos_embed.detach()
tmpdir = tempfile.gettempdir()
checkpoint = os.path.join(tmpdir, 'test.pth')
torch.save(pretrain_model.state_dict(), checkpoint)
# test load checkpoint
finetune_cfg = dict(
arch='b',
img_size=224,
patch_size=16,
init_cfg=dict(type='Pretrained', checkpoint=checkpoint))
finetune_model = VisionTransformer(**finetune_cfg)
finetune_model.init_weights()
assert torch.allclose(finetune_model.pos_embed, pos_embed_weight)
# test load checkpoint with different img_size
finetune_cfg = dict(
arch='b',
img_size=384,
patch_size=16,
init_cfg=dict(type='Pretrained', checkpoint=checkpoint))
finetune_model = VisionTransformer(**finetune_cfg)
finetune_model.init_weights()
resized_pos_embed = timm_resize_pos_embed(pos_embed_weight,
finetune_model.pos_embed)
assert torch.allclose(finetune_model.pos_embed, resized_pos_embed)
os.remove(checkpoint)
# Test forward with dynamic input size
imgs1 = torch.randn(3, 3, 224, 224)
imgs2 = torch.randn(3, 3, 256, 256)
imgs3 = torch.randn(3, 3, 256, 309)
cfg = deepcopy(self.cfg)
model = VisionTransformer(**cfg)
for imgs in [imgs1, imgs2, imgs3]:
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token, cls_token = outs[-1]
expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
math.ceil(imgs.shape[3] / 16))
self.assertEqual(patch_token.shape, (3, 768, *expect_feat_shape))
self.assertEqual(cls_token.shape, (3, 768))

View File

@ -0,0 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn.functional as F
def timm_resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
"""Timm version pos embed resize function.
copied from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
""" # noqa:E501
ntok_new = posemb_new.shape[1]
if num_tokens:
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0,
num_tokens:]
ntok_new -= num_tokens
else:
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
gs_old = int(math.sqrt(len(posemb_grid)))
if not len(gs_new): # backwards compatibility
gs_new = [int(math.sqrt(ntok_new))] * 2
assert len(gs_new) >= 2
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old,
-1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(
posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
posemb_grid = posemb_grid.permute(0, 2, 3,
1).reshape(1, gs_new[0] * gs_new[1], -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
return posemb
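A quick usage sketch of the helper above, reusing its imports: resizing a ViT-B/16 position embedding trained at 224x224 (1 cls token + 14x14 patches) for 384x384 inputs (1 cls token + 24x24 patches); the tensors are random placeholders for illustration:
# Illustration only: placeholder tensors, not real pretrained weights.
old_pos_embed = torch.randn(1, 1 + 14 * 14, 768)   # pretrained pos_embed
new_pos_embed = torch.zeros(1, 1 + 24 * 24, 768)   # target shape only

resized = timm_resize_pos_embed(old_pos_embed, new_pos_embed, num_tokens=1)
assert resized.shape == (1, 1 + 24 * 24, 768)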

View File

@ -1,5 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from unittest import TestCase
from unittest.mock import ANY, MagicMock
import pytest
import torch
from mmcls.models.utils.attention import ShiftWindowMSA, WindowMSA
@ -22,157 +25,177 @@ def get_relative_position_index(window_size):
return relative_position_index
def test_window_msa():
batch_size = 1
num_windows = (4, 4)
embed_dims = 96
window_size = (7, 7)
num_heads = 4
attn = WindowMSA(
embed_dims=embed_dims, window_size=window_size, num_heads=num_heads)
inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
window_size[0] * window_size[1], embed_dims))
class TestWindowMSA(TestCase):
# test forward
output = attn(inputs)
assert output.shape == inputs.shape
assert attn.relative_position_bias_table.shape == (
(2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
def test_forward(self):
attn = WindowMSA(embed_dims=96, window_size=(7, 7), num_heads=4)
inputs = torch.rand((16, 7 * 7, 96))
output = attn(inputs)
self.assertEqual(output.shape, inputs.shape)
# test relative_position_bias_table init
attn.init_weights()
assert abs(attn.relative_position_bias_table).sum() > 0
# test non-square window_size
attn = WindowMSA(embed_dims=96, window_size=(6, 7), num_heads=4)
inputs = torch.rand((16, 6 * 7, 96))
output = attn(inputs)
self.assertEqual(output.shape, inputs.shape)
# test non-square window_size
window_size = (6, 7)
attn = WindowMSA(
embed_dims=embed_dims, window_size=window_size, num_heads=num_heads)
inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
window_size[0] * window_size[1], embed_dims))
output = attn(inputs)
assert output.shape == inputs.shape
def test_relative_pos_embed(self):
attn = WindowMSA(embed_dims=96, window_size=(7, 8), num_heads=4)
self.assertEqual(attn.relative_position_bias_table.shape,
((2 * 7 - 1) * (2 * 8 - 1), 4))
# test relative_position_index
expected_rel_pos_index = get_relative_position_index((7, 8))
self.assertTrue(
torch.allclose(attn.relative_position_index,
expected_rel_pos_index))
# test relative_position_index
expected_rel_pos_index = get_relative_position_index(window_size)
assert (attn.relative_position_index == expected_rel_pos_index).all()
# test default init
self.assertTrue(
torch.allclose(attn.relative_position_bias_table,
torch.tensor(0.)))
attn.init_weights()
self.assertFalse(
torch.allclose(attn.relative_position_bias_table,
torch.tensor(0.)))
# test qkv_bias=True
attn = WindowMSA(
embed_dims=embed_dims,
window_size=window_size,
num_heads=num_heads,
qkv_bias=True)
assert attn.qkv.bias.shape == (embed_dims * 3, )
def test_qkv_bias(self):
# test qkv_bias=True
attn = WindowMSA(
embed_dims=96, window_size=(7, 7), num_heads=4, qkv_bias=True)
self.assertEqual(attn.qkv.bias.shape, (96 * 3, ))
# test qkv_bias=False
attn = WindowMSA(
embed_dims=embed_dims,
window_size=window_size,
num_heads=num_heads,
qkv_bias=False)
assert attn.qkv.bias is None
# test qkv_bias=False
attn = WindowMSA(
embed_dims=96, window_size=(7, 7), num_heads=4, qkv_bias=False)
self.assertIsNone(attn.qkv.bias)
# test default qk_scale
attn = WindowMSA(
embed_dims=embed_dims,
window_size=window_size,
num_heads=num_heads,
qk_scale=None)
head_dims = embed_dims // num_heads
assert np.isclose(attn.scale, head_dims**-0.5)
def test_qk_scale(self):
# test default qk_scale
attn = WindowMSA(
embed_dims=96, window_size=(7, 7), num_heads=4, qk_scale=None)
head_dims = 96 // 4
self.assertAlmostEqual(attn.scale, head_dims**-0.5)
# test specified qk_scale
attn = WindowMSA(
embed_dims=embed_dims,
window_size=window_size,
num_heads=num_heads,
qk_scale=0.3)
assert attn.scale == 0.3
# test specified qk_scale
attn = WindowMSA(
embed_dims=96, window_size=(7, 7), num_heads=4, qk_scale=0.3)
self.assertEqual(attn.scale, 0.3)
# test attn_drop
attn = WindowMSA(
embed_dims=embed_dims,
window_size=window_size,
num_heads=num_heads,
attn_drop=1.0)
inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
window_size[0] * window_size[1], embed_dims))
# drop all attn output, output should be equal to proj.bias
assert torch.allclose(attn(inputs), attn.proj.bias)
def test_attn_drop(self):
inputs = torch.rand(16, 7 * 7, 96)
attn = WindowMSA(
embed_dims=96, window_size=(7, 7), num_heads=4, attn_drop=1.0)
# drop all attn output, output should be equal to proj.bias
self.assertTrue(torch.allclose(attn(inputs), attn.proj.bias))
# test prob_drop
attn = WindowMSA(
embed_dims=embed_dims,
window_size=window_size,
num_heads=num_heads,
proj_drop=1.0)
assert (attn(inputs) == 0).all()
def test_prob_drop(self):
inputs = torch.rand(16, 7 * 7, 96)
attn = WindowMSA(
embed_dims=96, window_size=(7, 7), num_heads=4, proj_drop=1.0)
self.assertTrue(torch.allclose(attn(inputs), torch.tensor(0.)))
def test_mask(self):
inputs = torch.rand(16, 7 * 7, 96)
attn = WindowMSA(embed_dims=96, window_size=(7, 7), num_heads=4)
mask = torch.zeros((4, 49, 49))
# Mask the attention to and from the first token
mask[:, 0, :] = -100
mask[:, :, 0] = -100
outs = attn(inputs, mask=mask)
inputs[:, 0, :].normal_()
outs_with_mask = attn(inputs, mask=mask)
torch.testing.assert_allclose(outs[:, 1:, :], outs_with_mask[:, 1:, :])
def test_shift_window_msa():
batch_size = 1
embed_dims = 96
input_resolution = (14, 14)
num_heads = 4
window_size = 7
class TestShiftWindowMSA(TestCase):
# test forward
attn = ShiftWindowMSA(
embed_dims=embed_dims,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size)
inputs = torch.rand(
(batch_size, input_resolution[0] * input_resolution[1], embed_dims))
output = attn(inputs)
assert output.shape == (inputs.shape)
assert attn.w_msa.relative_position_bias_table.shape == ((2 * window_size -
1)**2, num_heads)
def test_forward(self):
inputs = torch.rand((1, 14 * 14, 96))
attn = ShiftWindowMSA(embed_dims=96, window_size=7, num_heads=4)
output = attn(inputs, (14, 14))
self.assertEqual(output.shape, inputs.shape)
self.assertEqual(attn.w_msa.relative_position_bias_table.shape,
((2 * 7 - 1)**2, 4))
# test forward with shift_size
attn = ShiftWindowMSA(
embed_dims=embed_dims,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=1)
output = attn(inputs)
assert output.shape == (inputs.shape)
# test forward with shift_size
attn = ShiftWindowMSA(
embed_dims=96, window_size=7, num_heads=4, shift_size=3)
output = attn(inputs, (14, 14))
assert output.shape == (inputs.shape)
# test relative_position_bias_table init
attn.init_weights()
assert abs(attn.w_msa.relative_position_bias_table).sum() > 0
# test irregular input shape
input_resolution = (19, 18)
attn = ShiftWindowMSA(embed_dims=96, num_heads=4, window_size=7)
inputs = torch.rand((1, 19 * 18, 96))
output = attn(inputs, input_resolution)
assert output.shape == (inputs.shape)
# test dropout_layer
attn = ShiftWindowMSA(
embed_dims=embed_dims,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
dropout_layer=dict(type='DropPath', drop_prob=0.5))
torch.manual_seed(0)
output = attn(inputs)
assert (output == 0).all()
# test wrong input_resolution
input_resolution = (14, 14)
attn = ShiftWindowMSA(embed_dims=96, num_heads=4, window_size=7)
inputs = torch.rand((1, 14 * 14, 96))
with pytest.raises(AssertionError):
attn(inputs, (14, 15))
# test auto_pad
input_resolution = (19, 18)
attn = ShiftWindowMSA(
embed_dims=embed_dims,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
auto_pad=True)
assert attn.pad_r == 3
assert attn.pad_b == 2
def test_pad_small_map(self):
# test pad_small_map=True
inputs = torch.rand((1, 6 * 7, 96))
attn = ShiftWindowMSA(
embed_dims=96,
window_size=7,
num_heads=4,
shift_size=3,
pad_small_map=True)
attn.get_attn_mask = MagicMock(wraps=attn.get_attn_mask)
output = attn(inputs, (6, 7))
self.assertEqual(output.shape, inputs.shape)
attn.get_attn_mask.assert_called_once_with((7, 7),
window_size=7,
shift_size=3,
device=ANY)
# test small input_resolution
input_resolution = (5, 6)
attn = ShiftWindowMSA(
embed_dims=embed_dims,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=3,
auto_pad=True)
assert attn.window_size == 5
assert attn.shift_size == 0
# test pad_small_map=False
inputs = torch.rand((1, 6 * 7, 96))
attn = ShiftWindowMSA(
embed_dims=96,
window_size=7,
num_heads=4,
shift_size=3,
pad_small_map=False)
with self.assertRaisesRegex(AssertionError, r'the window size \(7\)'):
attn(inputs, (6, 7))
# test pad_small_map=False, and the input size equals to window size
inputs = torch.rand((1, 7 * 7, 96))
attn.get_attn_mask = MagicMock(wraps=attn.get_attn_mask)
output = attn(inputs, (7, 7))
self.assertEqual(output.shape, inputs.shape)
attn.get_attn_mask.assert_called_once_with((7, 7),
window_size=7,
shift_size=0,
device=ANY)
def test_drop_layer(self):
inputs = torch.rand((1, 14 * 14, 96))
attn = ShiftWindowMSA(
embed_dims=96,
window_size=7,
num_heads=4,
dropout_layer=dict(type='Dropout', drop_prob=1.0))
attn.init_weights()
# drop the whole attn output, so the result should be all zeros
self.assertTrue(
torch.allclose(attn(inputs, (14, 14)), torch.tensor(0.)))
def test_deprecation(self):
# test deprecated arguments
with pytest.warns(DeprecationWarning):
ShiftWindowMSA(
embed_dims=96,
num_heads=4,
window_size=7,
input_resolution=(14, 14))
with pytest.warns(DeprecationWarning):
ShiftWindowMSA(
embed_dims=96, num_heads=4, window_size=7, auto_pad=True)
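As a closing note, the `pad_small_map` behaviour exercised above amounts to padding a feature map that is smaller than the attention window up to a full window before window partitioning. A rough standalone sketch in plain PyTorch (assumed behaviour, not the mmcls code):
import torch
import torch.nn.functional as F

def pad_to_window(x, window_size):
    # Sketch only: x is a (B, H, W, C) feature map; pad H and W up to
    # multiples of window_size on the bottom/right.
    B, H, W, C = x.shape
    pad_r = (window_size - W % window_size) % window_size
    pad_b = (window_size - H % window_size) % window_size
    return F.pad(x, (0, 0, 0, pad_r, 0, pad_b))

feat = torch.randn(1, 6, 7, 96)               # smaller than a 7x7 window
padded = pad_to_window(feat, window_size=7)
assert padded.shape == (1, 7, 7, 96)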