[Enhance] Support dynamic input shape for ViT-based algorithms. (#706)
* Move `resize_pos_embed` to `mmcls.models.utils`
* Refactor Vision Transformer
* Refactor DeiT
* Refactor MLP-Mixer
* Refactor Swin-Transformer
* Remove `indexing` arg
* Support dynamic inputs for t2t_vit
* Add copyright
* Fix bugs in swin transformer
* Add `pad_small_maps` option
* Update swin transformer
* Handle `attn_mask` in checkpoints of swin
* Improve by comments
parent 24ae53a4a0
commit c708770b42
@ -15,21 +15,38 @@ class DistilledVisionTransformer(VisionTransformer):
|
|||
distillation through attention <https://arxiv.org/abs/2012.12877>`_
|
||||
|
||||
Args:
|
||||
arch (str | dict): Vision Transformer architecture
|
||||
Default: 'b'
|
||||
img_size (int | tuple): Input image size
|
||||
patch_size (int | tuple): The patch size
|
||||
arch (str | dict): Vision Transformer architecture. If use string,
|
||||
choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
|
||||
and 'deit-base'. If use dict, it should have below keys:
|
||||
|
||||
- **embed_dims** (int): The dimensions of embedding.
|
||||
- **num_layers** (int): The number of transformer encoder layers.
|
||||
- **num_heads** (int): The number of heads in attention modules.
|
||||
- **feedforward_channels** (int): The hidden dimensions in
|
||||
feedforward modules.
|
||||
|
||||
Defaults to 'deit-base'.
|
||||
img_size (int | tuple): The expected input image shape. Because we
|
||||
support dynamic input shape, just set the argument to the most
|
||||
common input image shape. Defaults to 224.
|
||||
patch_size (int | tuple): The patch size in patch embedding.
|
||||
Defaults to 16.
|
||||
in_channels (int): The num of input channels. Defaults to 3.
|
||||
out_indices (Sequence | int): Output from which stages.
|
||||
Defaults to -1, means the last stage.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): stochastic depth rate. Defaults to 0.
|
||||
qkv_bias (bool): Whether to add bias for qkv in attention modules.
|
||||
Defaults to True.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
final_norm (bool): Whether to add an additional layer to normalize
|
||||
final feature map. Defaults to True.
|
||||
with_cls_token (bool): Whether concatenating class token into image
|
||||
tokens as transformer input. Defaults to True.
|
||||
output_cls_token (bool): Whether output the cls_token. If set True,
|
||||
`with_cls_token` must be True. Defaults to True.
|
||||
``with_cls_token`` must be True. Defaults to True.
|
||||
interpolate_mode (str): Select the interpolate mode for position
|
||||
embedding vector resize. Defaults to "bicubic".
|
||||
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
|
||||
|
@ -40,22 +57,31 @@ class DistilledVisionTransformer(VisionTransformer):
|
|||
"""
|
||||
num_extra_tokens = 2 # cls_token, dist_token
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DistilledVisionTransformer, self).__init__(*args, **kwargs)
|
||||
def __init__(self, arch='deit-base', *args, **kwargs):
|
||||
super(DistilledVisionTransformer, self).__init__(
|
||||
arch=arch, *args, **kwargs)
|
||||
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
|
||||
|
||||
def forward(self, x):
|
||||
B = x.shape[0]
|
||||
x = self.patch_embed(x)
|
||||
patch_resolution = self.patch_embed.patches_resolution
|
||||
x, patch_resolution = self.patch_embed(x)
|
||||
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
dist_token = self.dist_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
||||
x = x + self.pos_embed
|
||||
x = x + self.resize_pos_embed(
|
||||
self.pos_embed,
|
||||
self.patch_resolution,
|
||||
patch_resolution,
|
||||
mode=self.interpolate_mode,
|
||||
num_extra_tokens=self.num_extra_tokens)
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
if not self.with_cls_token:
|
||||
# Remove class token for transformer encoder input
|
||||
x = x[:, 2:]
|
||||
|
||||
outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
|
@ -65,10 +91,16 @@ class DistilledVisionTransformer(VisionTransformer):
|
|||
|
||||
if i in self.out_indices:
|
||||
B, _, C = x.shape
|
||||
patch_token = x[:, 2:].reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = x[:, 0]
|
||||
dist_token = x[:, 1]
|
||||
if self.with_cls_token:
|
||||
patch_token = x[:, 2:].reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = x[:, 0]
|
||||
dist_token = x[:, 1]
|
||||
else:
|
||||
patch_token = x.reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = None
|
||||
dist_token = None
|
||||
if self.output_cls_token:
|
||||
out = [patch_token, cls_token, dist_token]
|
||||
else:
|
||||
|
|
|
@ -3,11 +3,11 @@ from typing import Sequence
|
|||
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN
|
||||
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
|
||||
from mmcv.runner.base_module import BaseModule, ModuleList
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from ..utils import PatchEmbed, to_2tuple
|
||||
from ..utils import to_2tuple
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
|
@ -105,10 +105,20 @@ class MlpMixer(BaseBackbone):
|
|||
<https://arxiv.org/pdf/2105.01601.pdf>`_
|
||||
|
||||
Args:
|
||||
arch (str | dict): MLP Mixer architecture
|
||||
Defaults to 'b'.
|
||||
img_size (int | tuple): Input image size.
|
||||
patch_size (int | tuple): The patch size.
|
||||
arch (str | dict): MLP Mixer architecture. If use string, choose from
|
||||
'small', 'base' and 'large'. If use dict, it should have below
|
||||
keys:
|
||||
|
||||
- **embed_dims** (int): The dimensions of embedding.
|
||||
- **num_layers** (int): The number of MLP blocks.
|
||||
- **tokens_mlp_dims** (int): The hidden dimensions for tokens FFNs.
|
||||
- **channels_mlp_dims** (int): The hidden dimensions for
|
||||
channels FFNs.
|
||||
|
||||
Defaults to 'base'.
|
||||
img_size (int | tuple): The input image shape. Defaults to 224.
|
||||
patch_size (int | tuple): The patch size in patch embedding.
|
||||
Defaults to 16.
|
||||
out_indices (Sequence | int): Output from which layer.
|
||||
Defaults to -1, means the last layer.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
|
@ -149,7 +159,7 @@ class MlpMixer(BaseBackbone):
|
|||
}
|
||||
|
||||
def __init__(self,
|
||||
arch='b',
|
||||
arch='base',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
out_indices=-1,
|
||||
|
@ -184,14 +194,16 @@ class MlpMixer(BaseBackbone):
|
|||
self.img_size = to_2tuple(img_size)
|
||||
|
||||
_patch_cfg = dict(
|
||||
img_size=img_size,
|
||||
input_size=img_size,
|
||||
embed_dims=self.embed_dims,
|
||||
conv_cfg=dict(
|
||||
type='Conv2d', kernel_size=patch_size, stride=patch_size),
|
||||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
)
|
||||
_patch_cfg.update(patch_cfg)
|
||||
self.patch_embed = PatchEmbed(**_patch_cfg)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
self.patch_resolution = self.patch_embed.init_out_size
|
||||
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
|
||||
|
||||
if isinstance(out_indices, int):
|
||||
out_indices = [out_indices]
|
||||
|
@ -232,7 +244,10 @@ class MlpMixer(BaseBackbone):
|
|||
return getattr(self, self.norm1_name)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x)
|
||||
assert x.shape[2:] == self.img_size, \
|
||||
"The MLP-Mixer doesn't support dynamic input shape. " \
|
||||
f'Please input images with shape {self.img_size}'
|
||||
x, _ = self.patch_embed(x)
|
||||
|
||||
outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
|
|
|
@ -2,17 +2,18 @@
|
|||
from copy import deepcopy
|
||||
from typing import Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN
|
||||
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed, PatchMerging
|
||||
from mmcv.cnn.utils.weight_init import trunc_normal_
|
||||
from mmcv.runner.base_module import BaseModule, ModuleList
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from ..utils import PatchEmbed, PatchMerging, ShiftWindowMSA
|
||||
from ..utils import ShiftWindowMSA, resize_pos_embed, to_2tuple
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
|
@ -21,45 +22,41 @@ class SwinBlock(BaseModule):
|
|||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
input_resolution (Tuple[int, int]): The resolution of the input feature
|
||||
map.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int, optional): The height and width of the window.
|
||||
Defaults to 7.
|
||||
shift (bool, optional): Shift the attention window or not.
|
||||
window_size (int): The height and width of the window. Defaults to 7.
|
||||
shift (bool): Shift the attention window or not. Defaults to False.
|
||||
ffn_ratio (float): The expansion ratio of feedforward network hidden
|
||||
layer channels. Defaults to 4.
|
||||
drop_path (float): The drop path rate after attention and ffn.
|
||||
Defaults to 0.
|
||||
pad_small_map (bool): If True, pad the small feature map to the window
|
||||
size, which is commonly used in detection and segmentation. If False,
|
||||
avoid shifting window and shrink the window size to the size of
|
||||
feature map, which is commonly used in classification.
|
||||
Defaults to False.
|
||||
ffn_ratio (float, optional): The expansion ratio of feedforward network
|
||||
hidden layer channels. Defaults to 4.
|
||||
drop_path (float, optional): The drop path rate after attention and
|
||||
ffn. Defaults to 0.
|
||||
attn_cfgs (dict, optional): The extra config of Shift Window-MSA.
|
||||
attn_cfgs (dict): The extra config of Shift Window-MSA.
|
||||
Defaults to empty dict.
|
||||
ffn_cfgs (dict, optional): The extra config of FFN.
|
||||
Defaults to empty dict.
|
||||
norm_cfg (dict, optional): The config of norm layers.
|
||||
Defaults to dict(type='LN').
|
||||
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
|
||||
will save some memory while slowing down the training speed.
|
||||
Defaults to False.
|
||||
auto_pad (bool, optional): Auto pad the feature map to be divisible by
|
||||
window_size, Defaults to False.
|
||||
ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
|
||||
norm_cfg (dict): The config of norm layers.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Defaults to False.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Default: None.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
input_resolution,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
shift=False,
|
||||
ffn_ratio=4.,
|
||||
drop_path=0.,
|
||||
pad_small_map=False,
|
||||
attn_cfgs=dict(),
|
||||
ffn_cfgs=dict(),
|
||||
norm_cfg=dict(type='LN'),
|
||||
with_cp=False,
|
||||
auto_pad=False,
|
||||
init_cfg=None):
|
||||
|
||||
super(SwinBlock, self).__init__(init_cfg)
|
||||
|
@ -67,12 +64,11 @@ class SwinBlock(BaseModule):
|
|||
|
||||
_attn_cfgs = {
|
||||
'embed_dims': embed_dims,
|
||||
'input_resolution': input_resolution,
|
||||
'num_heads': num_heads,
|
||||
'shift_size': window_size // 2 if shift else 0,
|
||||
'window_size': window_size,
|
||||
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
|
||||
'auto_pad': auto_pad,
|
||||
'pad_small_map': pad_small_map,
|
||||
**attn_cfgs
|
||||
}
|
||||
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
|
@ -90,12 +86,12 @@ class SwinBlock(BaseModule):
|
|||
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
self.ffn = FFN(**_ffn_cfgs)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, hw_shape):
|
||||
|
||||
def _inner_forward(x):
|
||||
identity = x
|
||||
x = self.norm1(x)
|
||||
x = self.attn(x)
|
||||
x = self.attn(x, hw_shape)
|
||||
x = x + identity
|
||||
|
||||
identity = x
|
||||
|
@ -117,38 +113,39 @@ class SwinBlockSequence(BaseModule):
|
|||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
input_resolution (Tuple[int, int]): The resolution of the input feature
|
||||
map.
|
||||
depth (int): Number of successive swin transformer blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
downsample (bool, optional): Downsample the output of blocks by patch
|
||||
merging. Defaults to False.
|
||||
downsample_cfg (dict, optional): The extra config of the patch merging
|
||||
layer. Defaults to empty dict.
|
||||
drop_paths (Sequence[float] | float, optional): The drop path rate in
|
||||
each block. Defaults to 0.
|
||||
block_cfgs (Sequence[dict] | dict, optional): The extra config of each
|
||||
block. Defaults to empty dicts.
|
||||
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
|
||||
will save some memory while slowing down the training speed.
|
||||
window_size (int): The height and width of the window. Defaults to 7.
|
||||
downsample (bool): Downsample the output of blocks by patch merging.
|
||||
Defaults to False.
|
||||
downsample_cfg (dict): The extra config of the patch merging layer.
|
||||
Defaults to empty dict.
|
||||
drop_paths (Sequence[float] | float): The drop path rate in each block.
|
||||
Defaults to 0.
|
||||
block_cfgs (Sequence[dict] | dict): The extra config of each block.
|
||||
Defaults to empty dicts.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Defaults to False.
|
||||
pad_small_map (bool): If True, pad the small feature map to the window
|
||||
size, which is commonly used in detection and segmentation. If False,
|
||||
avoid shifting window and shrink the window size to the size of
|
||||
feature map, which is commonly used in classification.
|
||||
Defaults to False.
|
||||
auto_pad (bool, optional): Auto pad the feature map to be divisible by
|
||||
window_size, Defaults to False.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Default: None.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
input_resolution,
|
||||
depth,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
downsample=False,
|
||||
downsample_cfg=dict(),
|
||||
drop_paths=0.,
|
||||
block_cfgs=dict(),
|
||||
with_cp=False,
|
||||
auto_pad=False,
|
||||
pad_small_map=False,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
|
@ -159,17 +156,16 @@ class SwinBlockSequence(BaseModule):
|
|||
block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
self.input_resolution = input_resolution
|
||||
self.blocks = ModuleList()
|
||||
for i in range(depth):
|
||||
_block_cfg = {
|
||||
'embed_dims': embed_dims,
|
||||
'input_resolution': input_resolution,
|
||||
'num_heads': num_heads,
|
||||
'window_size': window_size,
|
||||
'shift': False if i % 2 == 0 else True,
|
||||
'drop_path': drop_paths[i],
|
||||
'with_cp': with_cp,
|
||||
'auto_pad': auto_pad,
|
||||
'pad_small_map': pad_small_map,
|
||||
**block_cfgs[i]
|
||||
}
|
||||
block = SwinBlock(**_block_cfg)
|
||||
|
@ -177,9 +173,8 @@ class SwinBlockSequence(BaseModule):
|
|||
|
||||
if downsample:
|
||||
_downsample_cfg = {
|
||||
'input_resolution': input_resolution,
|
||||
'in_channels': embed_dims,
|
||||
'expansion_ratio': 2,
|
||||
'out_channels': 2 * embed_dims,
|
||||
'norm_cfg': dict(type='LN'),
|
||||
**downsample_cfg
|
||||
}
|
||||
|
@ -187,20 +182,15 @@ class SwinBlockSequence(BaseModule):
|
|||
else:
|
||||
self.downsample = None
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, in_shape):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
x = block(x, in_shape)
|
||||
|
||||
if self.downsample:
|
||||
x = self.downsample(x)
|
||||
return x
|
||||
|
||||
@property
|
||||
def out_resolution(self):
|
||||
if self.downsample:
|
||||
return self.downsample.output_resolution
|
||||
x, out_shape = self.downsample(x, in_shape)
|
||||
else:
|
||||
return self.input_resolution
|
||||
out_shape = in_shape
|
||||
return x, out_shape
|
||||
|
||||
@property
|
||||
def out_channels(self):
|
||||
|
@ -212,7 +202,8 @@ class SwinBlockSequence(BaseModule):
|
|||
|
||||
@BACKBONES.register_module()
|
||||
class SwinTransformer(BaseBackbone):
|
||||
""" Swin Transformer
|
||||
"""Swin Transformer.
|
||||
|
||||
A PyTorch implement of : `Swin Transformer:
|
||||
Hierarchical Vision Transformer using Shifted Windows
|
||||
<https://arxiv.org/abs/2103.14030>`_
|
||||
|
@ -221,34 +212,47 @@ class SwinTransformer(BaseBackbone):
|
|||
https://github.com/microsoft/Swin-Transformer
|
||||
|
||||
Args:
|
||||
arch (str | dict): Swin Transformer architecture
|
||||
Defaults to 'T'.
|
||||
img_size (int | tuple): The size of input image.
|
||||
Defaults to 224.
|
||||
in_channels (int): The num of input channels.
|
||||
Defaults to 3.
|
||||
drop_rate (float): Dropout rate after embedding.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): Stochastic depth rate.
|
||||
Defaults to 0.1.
|
||||
arch (str | dict): Swin Transformer architecture. If use string, choose
|
||||
from 'tiny', 'small', 'base' and 'large'. If use dict, it should
|
||||
have below keys:
|
||||
|
||||
- **embed_dims** (int): The dimensions of embedding.
|
||||
- **depths** (List[int]): The number of blocks in each stage.
|
||||
- **num_heads** (List[int]): The number of heads in attention
|
||||
modules of each stage.
|
||||
|
||||
Defaults to 'tiny'.
|
||||
img_size (int | tuple): The expected input image shape. Because we
|
||||
support dynamic input shape, just set the argument to the most
|
||||
common input image shape. Defaults to 224.
|
||||
patch_size (int | tuple): The patch size in patch embedding.
|
||||
Defaults to 4.
|
||||
in_channels (int): The num of input channels. Defaults to 3.
|
||||
window_size (int): The height and width of the window. Defaults to 7.
|
||||
drop_rate (float): Dropout rate after embedding. Defaults to 0.
|
||||
drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
|
||||
use_abs_pos_embed (bool): If True, add absolute position embedding to
|
||||
the patch embedding. Defaults to False.
|
||||
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
|
||||
will save some memory while slowing down the training speed.
|
||||
Defaults to False.
|
||||
interpolate_mode (str): Select the interpolate mode for absolute
|
||||
position embedding vector resize. Defaults to "bicubic".
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Defaults to False.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters. Defaults to -1.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Defaults to False.
|
||||
auto_pad (bool): If True, auto pad feature map to fit window_size.
|
||||
pad_small_map (bool): If True, pad the small feature map to the window
|
||||
size, which is commonly used in detection and segmentation. If False,
|
||||
avoid shifting window and shrink the window size to the size of
|
||||
feature map, which is commonly used in classification.
|
||||
Defaults to False.
|
||||
norm_cfg (dict, optional): Config dict for normalization layer at end
|
||||
of backbone. Defaults to dict(type='LN')
|
||||
stage_cfgs (Sequence | dict, optional): Extra config dict for each
|
||||
stage. Defaults to empty dict.
|
||||
patch_cfg (dict, optional): Extra config dict for patch embedding.
|
||||
Defaults to empty dict.
|
||||
norm_cfg (dict): Config dict for normalization layer for all output
|
||||
features. Defaults to ``dict(type='LN')``
|
||||
stage_cfgs (Sequence[dict] | dict): Extra config dict for each
|
||||
stage. Defaults to an empty dict.
|
||||
patch_cfg (dict): Extra config dict for patch embedding.
|
||||
Defaults to an empty dict.
|
||||
init_cfg (dict, optional): The Config for initialization.
|
||||
Defaults to None.
|
||||
|
||||
|
@ -258,8 +262,7 @@ class SwinTransformer(BaseBackbone):
|
|||
>>> extra_config = dict(
|
||||
>>> arch='tiny',
|
||||
>>> stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
|
||||
>>> 'expansion_ratio': 3}),
|
||||
>>> auto_pad=True)
|
||||
>>> 'expansion_ratio': 3}))
|
||||
>>> self = SwinTransformer(**extra_config)
|
||||
>>> inputs = torch.rand(1, 3, 224, 224)
|
||||
>>> output = self.forward(inputs)
|
||||
|
@ -285,25 +288,29 @@ class SwinTransformer(BaseBackbone):
|
|||
'num_heads': [6, 12, 24, 48]}),
|
||||
} # yapf: disable
|
||||
|
||||
_version = 2
|
||||
_version = 3
|
||||
num_extra_tokens = 0
|
||||
|
||||
def __init__(self,
|
||||
arch='T',
|
||||
arch='tiny',
|
||||
img_size=224,
|
||||
patch_size=4,
|
||||
in_channels=3,
|
||||
window_size=7,
|
||||
drop_rate=0.,
|
||||
drop_path_rate=0.1,
|
||||
out_indices=(3, ),
|
||||
use_abs_pos_embed=False,
|
||||
auto_pad=False,
|
||||
interpolate_mode='bicubic',
|
||||
with_cp=False,
|
||||
frozen_stages=-1,
|
||||
norm_eval=False,
|
||||
pad_small_map=False,
|
||||
norm_cfg=dict(type='LN'),
|
||||
stage_cfgs=dict(),
|
||||
patch_cfg=dict(),
|
||||
init_cfg=None):
|
||||
super(SwinTransformer, self).__init__(init_cfg)
|
||||
super(SwinTransformer, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
if isinstance(arch, str):
|
||||
arch = arch.lower()
|
||||
|
@ -311,7 +318,7 @@ class SwinTransformer(BaseBackbone):
|
|||
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
|
||||
self.arch_settings = self.arch_zoo[arch]
|
||||
else:
|
||||
essential_keys = {'embed_dims', 'depths', 'num_head'}
|
||||
essential_keys = {'embed_dims', 'depths', 'num_heads'}
|
||||
assert isinstance(arch, dict) and set(arch) == essential_keys, \
|
||||
f'Custom arch needs a dict with keys {essential_keys}'
|
||||
self.arch_settings = arch
|
||||
|
@ -322,26 +329,28 @@ class SwinTransformer(BaseBackbone):
|
|||
self.num_layers = len(self.depths)
|
||||
self.out_indices = out_indices
|
||||
self.use_abs_pos_embed = use_abs_pos_embed
|
||||
self.auto_pad = auto_pad
|
||||
self.interpolate_mode = interpolate_mode
|
||||
self.frozen_stages = frozen_stages
|
||||
self.num_extra_tokens = 0
|
||||
|
||||
_patch_cfg = {
|
||||
'img_size': img_size,
|
||||
'in_channels': in_channels,
|
||||
'embed_dims': self.embed_dims,
|
||||
'conv_cfg': dict(type='Conv2d', kernel_size=4, stride=4),
|
||||
'norm_cfg': dict(type='LN'),
|
||||
**patch_cfg
|
||||
}
|
||||
_patch_cfg = dict(
|
||||
in_channels=in_channels,
|
||||
input_size=img_size,
|
||||
embed_dims=self.embed_dims,
|
||||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
norm_cfg=dict(type='LN'),
|
||||
)
|
||||
_patch_cfg.update(patch_cfg)
|
||||
self.patch_embed = PatchEmbed(**_patch_cfg)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
patches_resolution = self.patch_embed.patches_resolution
|
||||
self.patches_resolution = patches_resolution
|
||||
self.patch_resolution = self.patch_embed.init_out_size
|
||||
|
||||
if self.use_abs_pos_embed:
|
||||
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
|
||||
self.absolute_pos_embed = nn.Parameter(
|
||||
torch.zeros(1, num_patches, self.embed_dims))
|
||||
self._register_load_state_dict_pre_hook(
|
||||
self._prepare_abs_pos_embed)
|
||||
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
self.norm_eval = norm_eval
|
||||
|
@ -354,7 +363,6 @@ class SwinTransformer(BaseBackbone):
|
|||
|
||||
self.stages = ModuleList()
|
||||
embed_dims = [self.embed_dims]
|
||||
input_resolution = patches_resolution
|
||||
for i, (depth,
|
||||
num_heads) in enumerate(zip(self.depths, self.num_heads)):
|
||||
if isinstance(stage_cfgs, Sequence):
|
||||
|
@ -366,11 +374,11 @@ class SwinTransformer(BaseBackbone):
|
|||
'embed_dims': embed_dims[-1],
|
||||
'depth': depth,
|
||||
'num_heads': num_heads,
|
||||
'window_size': window_size,
|
||||
'downsample': downsample,
|
||||
'input_resolution': input_resolution,
|
||||
'drop_paths': dpr[:depth],
|
||||
'with_cp': with_cp,
|
||||
'auto_pad': auto_pad,
|
||||
'pad_small_map': pad_small_map,
|
||||
**stage_cfg
|
||||
}
|
||||
|
||||
|
@ -379,7 +387,6 @@ class SwinTransformer(BaseBackbone):
|
|||
|
||||
dpr = dpr[depth:]
|
||||
embed_dims.append(stage.out_channels)
|
||||
input_resolution = stage.out_resolution
|
||||
|
||||
for i in out_indices:
|
||||
if norm_cfg is not None:
|
||||
|
@ -401,18 +408,20 @@ class SwinTransformer(BaseBackbone):
|
|||
trunc_normal_(self.absolute_pos_embed, std=0.02)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x, hw_shape = self.patch_embed(x)
|
||||
if self.use_abs_pos_embed:
|
||||
x = x + self.absolute_pos_embed
|
||||
x = x + resize_pos_embed(
|
||||
self.absolute_pos_embed, self.patch_resolution, hw_shape,
|
||||
self.interpolate_mode, self.num_extra_tokens)
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
outs = []
|
||||
for i, stage in enumerate(self.stages):
|
||||
x = stage(x)
|
||||
x, hw_shape = stage(x, hw_shape)
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
out = norm_layer(x)
|
||||
out = out.view(-1, *stage.out_resolution,
|
||||
out = out.view(-1, *hw_shape,
|
||||
stage.out_channels).permute(0, 3, 1,
|
||||
2).contiguous()
|
||||
outs.append(out)
|
||||
|
@ -433,6 +442,12 @@ class SwinTransformer(BaseBackbone):
|
|||
convert_key = k.replace('norm.', f'norm{final_stage_num}.')
|
||||
state_dict[convert_key] = state_dict[k]
|
||||
del state_dict[k]
|
||||
if (version is None
|
||||
or version < 3) and self.__class__ is SwinTransformer:
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
for k in state_dict_keys:
|
||||
if 'attn_mask' in k:
|
||||
del state_dict[k]
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata,
|
||||
*args, **kwargs)
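Since the attention mask is now built from the runtime feature-map size, the backbone accepts inputs whose spatial shape differs from img_size, provided the last-stage map is not smaller than the window (or pad_small_map=True). A hedged sketch, assuming the import path and that forward() returns a tuple of stage outputs:

import torch
from mmcls.models.backbones import SwinTransformer  # assumed path

model = SwinTransformer(arch='tiny', img_size=224)
model.eval()
with torch.no_grad():
    # 224x256 input -> the last stage keeps a (224/32, 256/32) = (7, 8) grid.
    out = model(torch.rand(1, 3, 224, 256))[-1]
assert out.shape[2:] == (7, 8)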
|
||||
|
@ -461,3 +476,26 @@ class SwinTransformer(BaseBackbone):
|
|||
# trick: eval have effect on BatchNorm only
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
|
||||
|
||||
def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs):
|
||||
name = prefix + 'absolute_pos_embed'
|
||||
if name not in state_dict.keys():
|
||||
return
|
||||
|
||||
ckpt_pos_embed_shape = state_dict[name].shape
|
||||
if self.absolute_pos_embed.shape != ckpt_pos_embed_shape:
|
||||
from mmcls.utils import get_root_logger
|
||||
logger = get_root_logger()
|
||||
logger.info(
|
||||
'Resize the absolute_pos_embed shape from '
|
||||
f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.')
|
||||
|
||||
ckpt_pos_embed_shape = to_2tuple(
|
||||
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
|
||||
pos_embed_shape = self.patch_embed.init_out_size
|
||||
|
||||
state_dict[name] = resize_pos_embed(state_dict[name],
|
||||
ckpt_pos_embed_shape,
|
||||
pos_embed_shape,
|
||||
self.interpolate_mode,
|
||||
self.num_extra_tokens)
|
||||
|
|
|
@ -11,7 +11,7 @@ from mmcv.cnn.utils.weight_init import trunc_normal_
|
|||
from mmcv.runner.base_module import BaseModule, ModuleList
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from ..utils import MultiheadAttention
|
||||
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
|
@ -173,26 +173,44 @@ class T2TModule(BaseModule):
|
|||
raise NotImplementedError("Performer hasn't been implemented.")
|
||||
|
||||
# there are 3 soft splits, with strides 4, 2, 2 respectively
|
||||
self.num_patches = (img_size // (4 * 2 * 2))**2
|
||||
out_side = img_size // (4 * 2 * 2)
|
||||
self.init_out_size = [out_side, out_side]
|
||||
self.num_patches = out_side**2
|
||||
|
||||
@staticmethod
|
||||
def _get_unfold_size(unfold: nn.Unfold, input_size):
|
||||
h, w = input_size
|
||||
kernel_size = to_2tuple(unfold.kernel_size)
|
||||
stride = to_2tuple(unfold.stride)
|
||||
padding = to_2tuple(unfold.padding)
|
||||
dilation = to_2tuple(unfold.dilation)
|
||||
|
||||
h_out = (h + 2 * padding[0] - dilation[0] *
|
||||
(kernel_size[0] - 1) - 1) // stride[0] + 1
|
||||
w_out = (w + 2 * padding[1] - dilation[1] *
|
||||
(kernel_size[1] - 1) - 1) // stride[1] + 1
|
||||
return (h_out, w_out)
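The formula above mirrors how nn.Unfold computes its output length; a quick self-contained check with arbitrary sizes:

import torch
import torch.nn as nn

# Compare the closed-form output size with what nn.Unfold actually produces.
unfold = nn.Unfold(kernel_size=7, stride=4, padding=2)
h, w = 56, 48
h_out = (h + 2 * 2 - 1 * (7 - 1) - 1) // 4 + 1   # 14
w_out = (w + 2 * 2 - 1 * (7 - 1) - 1) // 4 + 1   # 12
x = torch.rand(1, 3, h, w)
assert unfold(x).shape[-1] == h_out * w_out       # L = 14 * 12 = 168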
|
||||
|
||||
def forward(self, x):
|
||||
# step0: soft split
|
||||
hw_shape = self._get_unfold_size(self.soft_split0, x.shape[2:])
|
||||
x = self.soft_split0(x).transpose(1, 2)
|
||||
|
||||
for step in [1, 2]:
|
||||
# re-structurization/reconstruction
|
||||
attn = getattr(self, f'attention{step}')
|
||||
x = attn(x).transpose(1, 2)
|
||||
B, C, new_HW = x.shape
|
||||
x = x.reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))
|
||||
B, C, _ = x.shape
|
||||
x = x.reshape(B, C, hw_shape[0], hw_shape[1])
|
||||
|
||||
# soft split
|
||||
soft_split = getattr(self, f'soft_split{step}')
|
||||
hw_shape = self._get_unfold_size(soft_split, hw_shape)
|
||||
x = soft_split(x).transpose(1, 2)
|
||||
|
||||
# final tokens
|
||||
x = self.project(x)
|
||||
return x
|
||||
return x, hw_shape
|
||||
|
||||
|
||||
def get_sinusoid_encoding(n_position, embed_dims):
|
||||
|
@ -231,43 +249,52 @@ class T2T_ViT(BaseBackbone):
|
|||
Transformers from Scratch on ImageNet <https://arxiv.org/abs/2101.11986>`_
|
||||
|
||||
Args:
|
||||
img_size (int): Input image size.
|
||||
img_size (int | tuple): The expected input image shape. Because we
|
||||
support dynamic input shape, just set the argument to the most
|
||||
common input image shape. Defaults to 224.
|
||||
in_channels (int): Number of input channels.
|
||||
embed_dims (int): Embedding dimension.
|
||||
t2t_cfg (dict): Extra config of Tokens-to-Token module.
|
||||
Defaults to an empty dict.
|
||||
drop_rate (float): Dropout rate after position embedding.
|
||||
Defaults to 0.
|
||||
num_layers (int): Num of transformer layers in encoder.
|
||||
Defaults to 14.
|
||||
out_indices (Sequence | int): Output from which stages.
|
||||
Defaults to -1, means the last stage.
|
||||
layer_cfgs (Sequence | dict): Configs of each transformer layer in
|
||||
encoder. Defaults to an empty dict.
|
||||
drop_rate (float): Dropout rate after position embedding.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): stochastic depth rate. Defaults to 0.
|
||||
norm_cfg (dict): Config dict for normalization layer. Defaults to
|
||||
``dict(type='LN')``.
|
||||
final_norm (bool): Whether to add an additional layer to normalize
|
||||
final feature map. Defaults to True.
|
||||
output_cls_token (bool): Whether output the cls_token.
|
||||
Defaults to True.
|
||||
with_cls_token (bool): Whether concatenating class token into image
|
||||
tokens as transformer input. Defaults to True.
|
||||
output_cls_token (bool): Whether output the cls_token. If set True,
|
||||
``with_cls_token`` must be True. Defaults to True.
|
||||
interpolate_mode (str): Select the interpolate mode for position
|
||||
embedding vector resize. Defaults to "bicubic".
|
||||
t2t_cfg (dict): Extra config of Tokens-to-Token module.
|
||||
Defaults to an empty dict.
|
||||
layer_cfgs (Sequence | dict): Configs of each transformer layer in
|
||||
encoder. Defaults to an empty dict.
|
||||
init_cfg (dict, optional): The Config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
num_extra_tokens = 1 # cls_token
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
in_channels=3,
|
||||
embed_dims=384,
|
||||
t2t_cfg=dict(),
|
||||
drop_rate=0.,
|
||||
num_layers=14,
|
||||
out_indices=-1,
|
||||
layer_cfgs=dict(),
|
||||
drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_cfg=dict(type='LN'),
|
||||
final_norm=True,
|
||||
with_cls_token=True,
|
||||
output_cls_token=True,
|
||||
interpolate_mode='bicubic',
|
||||
t2t_cfg=dict(),
|
||||
layer_cfgs=dict(),
|
||||
init_cfg=None):
|
||||
super(T2T_ViT, self).__init__(init_cfg)
|
||||
|
||||
|
@ -277,30 +304,41 @@ class T2T_ViT(BaseBackbone):
|
|||
in_channels=in_channels,
|
||||
embed_dims=embed_dims,
|
||||
**t2t_cfg)
|
||||
num_patches = self.tokens_to_token.num_patches
|
||||
self.patch_resolution = self.tokens_to_token.init_out_size
|
||||
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
|
||||
|
||||
# Class token
|
||||
# Set cls token
|
||||
if output_cls_token:
|
||||
assert with_cls_token is True, f'with_cls_token must be True if ' \
|
||||
f'set output_cls_token to True, but got {with_cls_token}'
|
||||
self.with_cls_token = with_cls_token
|
||||
self.output_cls_token = output_cls_token
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
|
||||
self.num_extra_tokens = 1
|
||||
|
||||
# Position Embedding
|
||||
sinusoid_table = get_sinusoid_encoding(num_patches + 1, embed_dims)
|
||||
# Set position embedding
|
||||
self.interpolate_mode = interpolate_mode
|
||||
sinusoid_table = get_sinusoid_encoding(
|
||||
num_patches + self.num_extra_tokens, embed_dims)
|
||||
self.register_buffer('pos_embed', sinusoid_table)
|
||||
self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
|
||||
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
|
||||
if isinstance(out_indices, int):
|
||||
out_indices = [out_indices]
|
||||
assert isinstance(out_indices, Sequence), \
|
||||
f'"out_indices" must by a sequence or int, ' \
|
||||
f'"out_indices" must be a sequence or int, ' \
|
||||
f'get {type(out_indices)} instead.'
|
||||
for i, index in enumerate(out_indices):
|
||||
if index < 0:
|
||||
out_indices[i] = num_layers + index
|
||||
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
|
||||
assert 0 <= out_indices[i] <= num_layers, \
|
||||
f'Invalid out_indices {index}'
|
||||
self.out_indices = out_indices
|
||||
|
||||
# stochastic depth decay rule
|
||||
dpr = [x for x in np.linspace(0, drop_path_rate, num_layers)]
|
||||
|
||||
self.encoder = ModuleList()
|
||||
for i in range(num_layers):
|
||||
if isinstance(layer_cfgs, Sequence):
|
||||
|
@ -336,17 +374,49 @@ class T2T_ViT(BaseBackbone):
|
|||
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
|
||||
def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
|
||||
name = prefix + 'pos_embed'
|
||||
if name not in state_dict.keys():
|
||||
return
|
||||
|
||||
ckpt_pos_embed_shape = state_dict[name].shape
|
||||
if self.pos_embed.shape != ckpt_pos_embed_shape:
|
||||
from mmcls.utils import get_root_logger
|
||||
logger = get_root_logger()
|
||||
logger.info(
|
||||
f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
|
||||
f'to {self.pos_embed.shape}.')
|
||||
|
||||
ckpt_pos_embed_shape = to_2tuple(
|
||||
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
|
||||
pos_embed_shape = self.tokens_to_token.init_out_size
|
||||
|
||||
state_dict[name] = resize_pos_embed(state_dict[name],
|
||||
ckpt_pos_embed_shape,
|
||||
pos_embed_shape,
|
||||
self.interpolate_mode,
|
||||
self.num_extra_tokens)
|
||||
|
||||
def forward(self, x):
|
||||
B = x.shape[0]
|
||||
x = self.tokens_to_token(x)
|
||||
num_patches = self.tokens_to_token.num_patches
|
||||
patch_resolution = [int(np.sqrt(num_patches))] * 2
|
||||
x, patch_resolution = self.tokens_to_token(x)
|
||||
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x = x + self.pos_embed
|
||||
|
||||
x = x + resize_pos_embed(
|
||||
self.pos_embed,
|
||||
self.patch_resolution,
|
||||
patch_resolution,
|
||||
mode=self.interpolate_mode,
|
||||
num_extra_tokens=self.num_extra_tokens)
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
if not self.with_cls_token:
|
||||
# Remove class token for transformer encoder input
|
||||
x = x[:, 1:]
|
||||
|
||||
outs = []
|
||||
for i, layer in enumerate(self.encoder):
|
||||
x = layer(x)
|
||||
|
@ -356,9 +426,14 @@ class T2T_ViT(BaseBackbone):
|
|||
|
||||
if i in self.out_indices:
|
||||
B, _, C = x.shape
|
||||
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = x[:, 0]
|
||||
if self.with_cls_token:
|
||||
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = x[:, 0]
|
||||
else:
|
||||
patch_token = x.reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = None
|
||||
if self.output_cls_token:
|
||||
out = [patch_token, cls_token]
|
||||
else:
|
||||
|
|
|
@ -4,15 +4,14 @@ from typing import Sequence
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN
|
||||
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
|
||||
from mmcv.cnn.utils.weight_init import trunc_normal_
|
||||
from mmcv.runner.base_module import BaseModule, ModuleList
|
||||
|
||||
from mmcls.utils import get_root_logger
|
||||
from ..builder import BACKBONES
|
||||
from ..utils import MultiheadAttention, PatchEmbed, to_2tuple
|
||||
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
|
@ -108,21 +107,38 @@ class VisionTransformer(BaseBackbone):
|
|||
for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_
|
||||
|
||||
Args:
|
||||
arch (str | dict): Vision Transformer architecture
|
||||
Default: 'b'
|
||||
img_size (int | tuple): Input image size
|
||||
patch_size (int | tuple): The patch size
|
||||
arch (str | dict): Vision Transformer architecture. If use string,
|
||||
choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
|
||||
and 'deit-base'. If use dict, it should have below keys:
|
||||
|
||||
- **embed_dims** (int): The dimensions of embedding.
|
||||
- **num_layers** (int): The number of transformer encoder layers.
|
||||
- **num_heads** (int): The number of heads in attention modules.
|
||||
- **feedforward_channels** (int): The hidden dimensions in
|
||||
feedforward modules.
|
||||
|
||||
Defaults to 'base'.
|
||||
img_size (int | tuple): The expected input image shape. Because we
|
||||
support dynamic input shape, just set the argument to the most
|
||||
common input image shape. Defaults to 224.
|
||||
patch_size (int | tuple): The patch size in patch embedding.
|
||||
Defaults to 16.
|
||||
in_channels (int): The num of input channels. Defaults to 3.
|
||||
out_indices (Sequence | int): Output from which stages.
|
||||
Defaults to -1, means the last stage.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): stochastic depth rate. Defaults to 0.
|
||||
qkv_bias (bool): Whether to add bias for qkv in attention modules.
|
||||
Defaults to True.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
final_norm (bool): Whether to add an additional layer to normalize
|
||||
final feature map. Defaults to True.
|
||||
with_cls_token (bool): Whether concatenating class token into image
|
||||
tokens as transformer input. Defaults to True.
|
||||
output_cls_token (bool): Whether output the cls_token. If set True,
|
||||
`with_cls_token` must be True. Defaults to True.
|
||||
``with_cls_token`` must be True. Defaults to True.
|
||||
interpolate_mode (str): Select the interpolate mode for position
|
||||
embedding vector resize. Defaults to "bicubic".
|
||||
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
|
||||
|
@ -138,7 +154,6 @@ class VisionTransformer(BaseBackbone):
|
|||
'num_layers': 8,
|
||||
'num_heads': 8,
|
||||
'feedforward_channels': 768 * 3,
|
||||
'qkv_bias': False
|
||||
}),
|
||||
**dict.fromkeys(
|
||||
['b', 'base'], {
|
||||
|
@ -180,14 +195,17 @@ class VisionTransformer(BaseBackbone):
|
|||
num_extra_tokens = 1 # cls_token
|
||||
|
||||
def __init__(self,
|
||||
arch='b',
|
||||
arch='base',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
out_indices=-1,
|
||||
drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
qkv_bias=True,
|
||||
norm_cfg=dict(type='LN', eps=1e-6),
|
||||
final_norm=True,
|
||||
with_cls_token=True,
|
||||
output_cls_token=True,
|
||||
interpolate_mode='bicubic',
|
||||
patch_cfg=dict(),
|
||||
|
@ -214,16 +232,23 @@ class VisionTransformer(BaseBackbone):
|
|||
|
||||
# Set patch embedding
|
||||
_patch_cfg = dict(
|
||||
img_size=img_size,
|
||||
in_channels=in_channels,
|
||||
input_size=img_size,
|
||||
embed_dims=self.embed_dims,
|
||||
conv_cfg=dict(
|
||||
type='Conv2d', kernel_size=patch_size, stride=patch_size),
|
||||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
)
|
||||
_patch_cfg.update(patch_cfg)
|
||||
self.patch_embed = PatchEmbed(**_patch_cfg)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
self.patch_resolution = self.patch_embed.init_out_size
|
||||
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
|
||||
|
||||
# Set cls token
|
||||
if output_cls_token:
|
||||
assert with_cls_token is True, f'with_cls_token must be True if ' \
|
||||
f'set output_cls_token to True, but got {with_cls_token}'
|
||||
self.with_cls_token = with_cls_token
|
||||
self.output_cls_token = output_cls_token
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
|
||||
|
||||
|
@ -232,6 +257,8 @@ class VisionTransformer(BaseBackbone):
|
|||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, num_patches + self.num_extra_tokens,
|
||||
self.embed_dims))
|
||||
self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
|
||||
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
|
||||
if isinstance(out_indices, int):
|
||||
|
@ -242,11 +269,12 @@ class VisionTransformer(BaseBackbone):
|
|||
for i, index in enumerate(out_indices):
|
||||
if index < 0:
|
||||
out_indices[i] = self.num_layers + index
|
||||
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
|
||||
assert 0 <= out_indices[i] <= self.num_layers, \
|
||||
f'Invalid out_indices {index}'
|
||||
self.out_indices = out_indices
|
||||
|
||||
# stochastic depth decay rule
|
||||
dpr = np.linspace(0, drop_path_rate, self.arch_settings['num_layers'])
|
||||
dpr = np.linspace(0, drop_path_rate, self.num_layers)
|
||||
|
||||
self.layers = ModuleList()
|
||||
if isinstance(layer_cfgs, dict):
|
||||
|
@ -259,7 +287,7 @@ class VisionTransformer(BaseBackbone):
|
|||
arch_settings['feedforward_channels'],
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
qkv_bias=self.arch_settings.get('qkv_bias', True),
|
||||
qkv_bias=qkv_bias,
|
||||
norm_cfg=norm_cfg)
|
||||
_layer_cfg.update(layer_cfgs[i])
|
||||
self.layers.append(TransformerEncoderLayer(**_layer_cfg))
|
||||
|
@ -270,8 +298,6 @@ class VisionTransformer(BaseBackbone):
|
|||
norm_cfg, self.embed_dims, postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self._prepare_checkpoint_hook)
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
return getattr(self, self.norm1_name)
|
||||
|
@ -283,7 +309,7 @@ class VisionTransformer(BaseBackbone):
|
|||
and self.init_cfg['type'] == 'Pretrained'):
|
||||
trunc_normal_(self.pos_embed, std=0.02)
|
||||
|
||||
def _prepare_checkpoint_hook(self, state_dict, prefix, *args, **kwargs):
|
||||
def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
|
||||
name = prefix + 'pos_embed'
|
||||
if name not in state_dict.keys():
|
||||
return
|
||||
|
@ -299,61 +325,38 @@ class VisionTransformer(BaseBackbone):
|
|||
|
||||
ckpt_pos_embed_shape = to_2tuple(
|
||||
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
|
||||
pos_embed_shape = self.patch_embed.patches_resolution
|
||||
pos_embed_shape = self.patch_embed.init_out_size
|
||||
|
||||
state_dict[name] = self.resize_pos_embed(state_dict[name],
|
||||
ckpt_pos_embed_shape,
|
||||
pos_embed_shape,
|
||||
self.interpolate_mode,
|
||||
self.num_extra_tokens)
|
||||
state_dict[name] = resize_pos_embed(state_dict[name],
|
||||
ckpt_pos_embed_shape,
|
||||
pos_embed_shape,
|
||||
self.interpolate_mode,
|
||||
self.num_extra_tokens)
|
||||
|
||||
@staticmethod
|
||||
def resize_pos_embed(pos_embed,
|
||||
src_shape,
|
||||
dst_shape,
|
||||
mode='bicubic',
|
||||
num_extra_tokens=1):
|
||||
"""Resize pos_embed weights.
|
||||
|
||||
Args:
|
||||
pos_embed (torch.Tensor): Position embedding weights with shape
|
||||
[1, L, C].
|
||||
src_shape (tuple): The resolution of downsampled origin training
|
||||
image.
|
||||
dst_shape (tuple): The resolution of downsampled new training
|
||||
image.
|
||||
mode (str): Algorithm used for upsampling:
|
||||
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
|
||||
``'trilinear'``. Default: ``'bicubic'``
|
||||
Return:
|
||||
torch.Tensor: The resized pos_embed of shape [1, L_new, C]
|
||||
"""
|
||||
assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
|
||||
_, L, C = pos_embed.shape
|
||||
src_h, src_w = src_shape
|
||||
assert L == src_h * src_w + num_extra_tokens
|
||||
extra_tokens = pos_embed[:, :num_extra_tokens]
|
||||
|
||||
src_weight = pos_embed[:, num_extra_tokens:]
|
||||
src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)
|
||||
|
||||
dst_weight = F.interpolate(
|
||||
src_weight, size=dst_shape, align_corners=False, mode=mode)
|
||||
dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
|
||||
|
||||
return torch.cat((extra_tokens, dst_weight), dim=1)
|
||||
def resize_pos_embed(*args, **kwargs):
|
||||
"""Interface for backward-compatibility."""
|
||||
return resize_pos_embed(*args, **kwargs)
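The helper now lives in mmcls.models.utils and keeps the semantics of the removed static method above; a small sketch of the expected shape bookkeeping:

import torch
from mmcls.models.utils import resize_pos_embed

# A [1, 1 + 14*14, 768] embedding trained at 224x224 with one cls token,
# resized for a 320x240 input (a 20x15 grid of 16x16 patches).
pos_embed = torch.rand(1, 1 + 14 * 14, 768)
resized = resize_pos_embed(pos_embed, (14, 14), (20, 15),
                           mode='bicubic', num_extra_tokens=1)
assert resized.shape == (1, 1 + 20 * 15, 768)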
|
||||
|
||||
def forward(self, x):
|
||||
B = x.shape[0]
|
||||
x = self.patch_embed(x)
|
||||
patch_resolution = self.patch_embed.patches_resolution
|
||||
x, patch_resolution = self.patch_embed(x)
|
||||
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x = x + self.pos_embed
|
||||
x = x + resize_pos_embed(
|
||||
self.pos_embed,
|
||||
self.patch_resolution,
|
||||
patch_resolution,
|
||||
mode=self.interpolate_mode,
|
||||
num_extra_tokens=self.num_extra_tokens)
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
if not self.with_cls_token:
|
||||
# Remove class token for transformer encoder input
|
||||
x = x[:, 1:]
|
||||
|
||||
outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
|
@ -363,9 +366,14 @@ class VisionTransformer(BaseBackbone):
|
|||
|
||||
if i in self.out_indices:
|
||||
B, _, C = x.shape
|
||||
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = x[:, 0]
|
||||
if self.with_cls_token:
|
||||
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = x[:, 0]
|
||||
else:
|
||||
patch_token = x.reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = None
|
||||
if self.output_cls_token:
|
||||
out = [patch_token, cls_token]
|
||||
else:
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
from .attention import MultiheadAttention, ShiftWindowMSA
|
||||
from .augment.augments import Augments
|
||||
from .channel_shuffle import channel_shuffle
|
||||
from .embed import HybridEmbed, PatchEmbed, PatchMerging
|
||||
from .embed import HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed
|
||||
from .helpers import is_tracing, to_2tuple, to_3tuple, to_4tuple, to_ntuple
|
||||
from .inverted_residual import InvertedResidual
|
||||
from .make_divisible import make_divisible
|
||||
|
@ -13,5 +13,5 @@ __all__ = [
|
|||
'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer',
|
||||
'to_ntuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'PatchEmbed',
|
||||
'PatchMerging', 'HybridEmbed', 'Augments', 'ShiftWindowMSA', 'is_tracing',
|
||||
'MultiheadAttention', 'ConditionalPositionEncoding'
|
||||
'MultiheadAttention', 'ConditionalPositionEncoding', 'resize_pos_embed'
|
||||
]
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
@ -126,14 +128,12 @@ class ShiftWindowMSA(BaseModule):
|
|||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
input_resolution (Tuple[int, int]): The resolution of the input feature
|
||||
map.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): The height and width of the window.
|
||||
shift_size (int, optional): The shift step of each window towards
|
||||
right-bottom. If zero, act as regular window-msa. Defaults to 0.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
|
||||
Default: True
|
||||
Defaults to True
|
||||
qk_scale (float | None, optional): Override default qk scale of
|
||||
head_dim ** -0.5 if set. Defaults to None.
|
||||
attn_drop (float, optional): Dropout ratio of attention weight.
|
||||
|
@ -141,15 +141,17 @@ class ShiftWindowMSA(BaseModule):
|
|||
proj_drop (float, optional): Dropout ratio of output. Defaults to 0.
|
||||
dropout_layer (dict, optional): The dropout_layer used before output.
|
||||
Defaults to dict(type='DropPath', drop_prob=0.).
|
||||
auto_pad (bool, optional): Auto pad the feature map to be divisible by
|
||||
window_size, Defaults to False.
|
||||
pad_small_map (bool): If True, pad the small feature map to the window
|
||||
size, which is commonly used in detection and segmentation. If False,
|
||||
avoid shifting window and shrink the window size to the size of
|
||||
feature map, which is commonly used in classification.
|
||||
Defaults to False.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Default: None.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
input_resolution,
|
||||
num_heads,
|
||||
window_size,
|
||||
shift_size=0,
|
||||
|
@ -158,53 +160,134 @@ class ShiftWindowMSA(BaseModule):
|
|||
attn_drop=0,
|
||||
proj_drop=0,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=0.),
|
||||
auto_pad=False,
|
||||
pad_small_map=False,
|
||||
input_resolution=None,
|
||||
auto_pad=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
self.input_resolution = input_resolution
|
||||
if input_resolution is not None or auto_pad is not None:
|
||||
warnings.warn(
|
||||
'The ShiftWindowMSA in new version has supported auto padding '
|
||||
'and dynamic input shape in all condition. And the argument '
|
||||
'`auto_pad` and `input_resolution` have been deprecated.',
|
||||
DeprecationWarning)
|
||||
|
||||
self.shift_size = shift_size
|
||||
self.window_size = window_size
|
||||
if min(self.input_resolution) <= self.window_size:
|
||||
# if window size is larger than input resolution, don't partition
|
||||
self.shift_size = 0
|
||||
self.window_size = min(self.input_resolution)
|
||||
assert 0 <= self.shift_size < self.window_size
|
||||
|
||||
self.w_msa = WindowMSA(embed_dims, to_2tuple(self.window_size),
|
||||
num_heads, qkv_bias, qk_scale, attn_drop,
|
||||
proj_drop)
|
||||
self.w_msa = WindowMSA(
|
||||
embed_dims=embed_dims,
|
||||
window_size=to_2tuple(self.window_size),
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
)
|
||||
|
||||
self.drop = build_dropout(dropout_layer)
|
||||
self.pad_small_map = pad_small_map
|
||||
|
||||
H, W = self.input_resolution
|
||||
# Handle auto padding
|
||||
self.auto_pad = auto_pad
|
||||
if self.auto_pad:
|
||||
self.pad_r = (self.window_size -
|
||||
W % self.window_size) % self.window_size
|
||||
self.pad_b = (self.window_size -
|
||||
H % self.window_size) % self.window_size
|
||||
self.H_pad = H + self.pad_b
|
||||
self.W_pad = W + self.pad_r
|
||||
else:
|
||||
H_pad, W_pad = self.input_resolution
|
||||
assert H_pad % self.window_size + W_pad % self.window_size == 0,\
|
||||
f'input_resolution({self.input_resolution}) is not divisible '\
|
||||
f'by window_size({self.window_size}). Please check feature '\
|
||||
f'map shape or set `auto_pad=True`.'
|
||||
self.H_pad, self.W_pad = H_pad, W_pad
|
||||
self.pad_r, self.pad_b = 0, 0
|
||||
def forward(self, query, hw_shape):
|
||||
B, L, C = query.shape
|
||||
H, W = hw_shape
|
||||
assert L == H * W, f"The query length {L} doesn't match the input "\
|
||||
f'shape ({H}, {W}).'
|
||||
query = query.view(B, H, W, C)
|
||||
|
||||
window_size = self.window_size
|
||||
shift_size = self.shift_size
|
||||
|
||||
if min(H, W) == window_size:
|
||||
# If not pad small feature map, avoid shifting when the window size
|
||||
# is equal to the size of feature map. It's to align with the
|
||||
# behavior of the original implementation.
|
||||
shift_size = shift_size if self.pad_small_map else 0
|
||||
elif min(H, W) < window_size:
|
||||
# In the original implementation, the window size will be shrunk
|
||||
# to the size of feature map. The behavior is different with
|
||||
# swin-transformer for downstream tasks. To support dynamic input
|
||||
# shape, we don't allow this feature.
|
||||
assert self.pad_small_map, \
|
||||
f'The input shape ({H}, {W}) is smaller than the window ' \
|
||||
f'size ({window_size}). Please set `pad_small_map=True`, or ' \
|
||||
'decrease the `window_size`.'
|
||||
|
||||
pad_r = (window_size - W % window_size) % window_size
|
||||
pad_b = (window_size - H % window_size) % window_size
|
||||
query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
|
||||
|
||||
H_pad, W_pad = query.shape[1], query.shape[2]
|
||||
|
||||
# cyclic shift
|
||||
if shift_size > 0:
|
||||
query = torch.roll(
|
||||
query, shifts=(-shift_size, -shift_size), dims=(1, 2))
|
||||
|
||||
attn_mask = self.get_attn_mask((H_pad, W_pad),
|
||||
window_size=window_size,
|
||||
shift_size=shift_size,
|
||||
device=query.device)
|
||||
|
||||
# nW*B, window_size, window_size, C
|
||||
query_windows = self.window_partition(query, window_size)
|
||||
# nW*B, window_size*window_size, C
|
||||
query_windows = query_windows.view(-1, window_size**2, C)
|
||||
|
||||
# W-MSA/SW-MSA (nW*B, window_size*window_size, C)
|
||||
attn_windows = self.w_msa(query_windows, mask=attn_mask)
|
||||
|
||||
# merge windows
|
||||
attn_windows = attn_windows.view(-1, window_size, window_size, C)
|
||||
|
||||
# B H' W' C
|
||||
shifted_x = self.window_reverse(attn_windows, H_pad, W_pad,
|
||||
window_size)
|
||||
# reverse cyclic shift
|
||||
if self.shift_size > 0:
|
||||
# calculate attention mask for SW-MSA
|
||||
img_mask = torch.zeros((1, self.H_pad, self.W_pad, 1)) # 1 H W 1
|
||||
h_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size,
|
||||
-self.shift_size), slice(-self.shift_size, None))
|
||||
w_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size,
|
||||
-self.shift_size), slice(-self.shift_size, None))
|
||||
x = torch.roll(
|
||||
shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
|
||||
if H != H_pad or W != W_pad:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
|
||||
x = x.view(B, H * W, C)
|
||||
|
||||
x = self.drop(x)
|
||||
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def window_reverse(windows, H, W, window_size):
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size,
|
||||
window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def window_partition(x, window_size):
|
||||
B, H, W, C = x.shape
|
||||
x = x.view(B, H // window_size, window_size, W // window_size,
|
||||
window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
|
||||
windows = windows.view(-1, window_size, window_size, C)
|
||||
return windows
|
||||
|
||||
@staticmethod
|
||||
def get_attn_mask(hw_shape, window_size, shift_size, device=None):
|
||||
if shift_size > 0:
|
||||
img_mask = torch.zeros(1, *hw_shape, 1, device=device)
|
||||
h_slices = (slice(0, -window_size), slice(-window_size,
|
||||
-shift_size),
|
||||
slice(-shift_size, None))
|
||||
w_slices = (slice(0, -window_size), slice(-window_size,
|
||||
-shift_size),
|
||||
slice(-shift_size, None))
|
||||
cnt = 0
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
|
@ -212,83 +295,15 @@ class ShiftWindowMSA(BaseModule):
|
|||
cnt += 1
|
||||
|
||||
# nW, window_size, window_size, 1
|
||||
mask_windows = self.window_partition(img_mask)
|
||||
mask_windows = mask_windows.view(
|
||||
-1, self.window_size * self.window_size)
|
||||
mask_windows = ShiftWindowMSA.window_partition(
|
||||
img_mask, window_size)
|
||||
mask_windows = mask_windows.view(-1, window_size * window_size)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0,
|
||||
float(-100.0)).masked_fill(
|
||||
attn_mask == 0, float(0.0))
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask == 0, 0.0)
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
self.register_buffer('attn_mask', attn_mask)
|
||||
|
||||
def forward(self, query):
|
||||
H, W = self.input_resolution
|
||||
B, L, C = query.shape
|
||||
assert L == H * W, 'input feature has wrong size'
|
||||
query = query.view(B, H, W, C)
|
||||
|
||||
if self.pad_r or self.pad_b:
|
||||
query = F.pad(query, (0, 0, 0, self.pad_r, 0, self.pad_b))
|
||||
|
||||
# cyclic shift
|
||||
if self.shift_size > 0:
|
||||
shifted_query = torch.roll(
|
||||
query,
|
||||
shifts=(-self.shift_size, -self.shift_size),
|
||||
dims=(1, 2))
|
||||
else:
|
||||
shifted_query = query
|
||||
|
||||
# nW*B, window_size, window_size, C
|
||||
query_windows = self.window_partition(shifted_query)
|
||||
# nW*B, window_size*window_size, C
|
||||
query_windows = query_windows.view(-1, self.window_size**2, C)
|
||||
|
||||
# W-MSA/SW-MSA (nW*B, window_size*window_size, C)
|
||||
attn_windows = self.w_msa(query_windows, mask=self.attn_mask)
|
||||
|
||||
# merge windows
|
||||
attn_windows = attn_windows.view(-1, self.window_size,
|
||||
self.window_size, C)
|
||||
|
||||
# B H' W' C
|
||||
shifted_x = self.window_reverse(attn_windows, self.H_pad, self.W_pad)
|
||||
# reverse cyclic shift
|
||||
if self.shift_size > 0:
|
||||
x = torch.roll(
|
||||
shifted_x,
|
||||
shifts=(self.shift_size, self.shift_size),
|
||||
dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
|
||||
if self.pad_r or self.pad_b:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
|
||||
x = x.view(B, H * W, C)
|
||||
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
def window_reverse(self, windows, H, W):
|
||||
window_size = self.window_size
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size,
|
||||
window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
return x
|
||||
|
||||
def window_partition(self, x):
|
||||
B, H, W, C = x.shape
|
||||
window_size = self.window_size
|
||||
x = x.view(B, H // window_size, window_size, W // window_size,
|
||||
window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
|
||||
windows = windows.view(-1, window_size, window_size, C)
|
||||
return windows
|
||||
return attn_mask
|
||||
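For reference, the two static helpers above are exact inverses of each other; a minimal round-trip sketch (import path as used by the unit tests further below):

import torch
from mmcls.models.utils.attention import ShiftWindowMSA

x = torch.rand(2, 14, 14, 96)  # B, H, W, C feature map
windows = ShiftWindowMSA.window_partition(x, window_size=7)
assert windows.shape == (2 * 2 * 2, 7, 7, 96)  # nW*B, window_size, window_size, C
restored = ShiftWindowMSA.window_reverse(windows, 14, 14, window_size=7)
assert torch.equal(restored, x)  # partition followed by reverse is the identity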


class MultiheadAttention(BaseModule):
@ -1,12 +1,59 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner.base_module import BaseModule

from .helpers import to_2tuple


def resize_pos_embed(pos_embed,
                     src_shape,
                     dst_shape,
                     mode='bicubic',
                     num_extra_tokens=1):
    """Resize pos_embed weights.

    Args:
        pos_embed (torch.Tensor): Position embedding weights with shape
            [1, L, C].
        src_shape (tuple): The resolution of downsampled origin training
            image, in format (H, W).
        dst_shape (tuple): The resolution of downsampled new training
            image, in format (H, W).
        mode (str): Algorithm used for upsampling. Choose one from 'nearest',
            'linear', 'bilinear', 'bicubic' and 'trilinear'.
            Defaults to 'bicubic'.
        num_extra_tokens (int): The number of extra tokens, such as cls_token.
            Defaults to 1.

    Returns:
        torch.Tensor: The resized pos_embed of shape [1, L_new, C]
    """
    if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
        return pos_embed
    assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
    _, L, C = pos_embed.shape
    src_h, src_w = src_shape
    assert L == src_h * src_w + num_extra_tokens, \
        f"The length of `pos_embed` ({L}) doesn't match the expected " \
        f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the' \
        '`img_size` argument.'
    extra_tokens = pos_embed[:, :num_extra_tokens]

    src_weight = pos_embed[:, num_extra_tokens:]
    src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)

    dst_weight = F.interpolate(
        src_weight, size=dst_shape, align_corners=False, mode=mode)
    dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)

    return torch.cat((extra_tokens, dst_weight), dim=1)
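A minimal usage sketch of `resize_pos_embed` (shapes for a ViT-base checkpoint; the import path assumes the move to `mmcls.models.utils` described in this PR):

import torch
from mmcls.models.utils import resize_pos_embed  # import path per this refactor

# ViT-base pos_embed pretrained at 224x224 with 16x16 patches: a 14x14 grid plus 1 cls token.
pos_embed = torch.rand(1, 14 * 14 + 1, 768)
# Resize it for a 256x309 input, i.e. a ceil(256/16) x ceil(309/16) = 16x20 grid.
resized = resize_pos_embed(pos_embed, src_shape=(14, 14), dst_shape=(16, 20))
assert resized.shape == (1, 16 * 20 + 1, 768)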


class PatchEmbed(BaseModule):
    """Image to Patch Embedding.

@ -32,6 +79,9 @@ class PatchEmbed(BaseModule):
                 conv_cfg=None,
                 init_cfg=None):
        super(PatchEmbed, self).__init__(init_cfg)
        warnings.warn('The `PatchEmbed` in mmcls will be deprecated. '
                      'Please use `mmcv.cnn.bricks.transformer.PatchEmbed`. '
                      "It's more general and supports dynamic input shape")

        if isinstance(img_size, int):
            img_size = to_2tuple(img_size)

@ -203,6 +253,10 @@ class PatchMerging(BaseModule):
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super().__init__(init_cfg)
        warnings.warn('The `PatchMerging` in mmcls will be deprecated. '
                      'Please use `mmcv.cnn.bricks.transformer.PatchMerging`. '
                      "It's more general and supports dynamic input shape")

        H, W = input_resolution
        self.input_resolution = input_resolution
        self.in_channels = in_channels
@ -52,8 +52,7 @@ backbone_configs = dict(
            arch='small',
            drop_path_rate=0.2,
            img_size=800,
            out_indices=(2, 3),
            auto_pad=True),
            out_indices=(2, 3)),
        out_channels=[384, 768]),
    timm_efficientnet=dict(
        backbone=dict(

@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.
@ -1,43 +1,131 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import os
import tempfile
from copy import deepcopy
from unittest import TestCase

import torch
from torch.nn.modules.batchnorm import _BatchNorm
from mmcv.runner import load_checkpoint, save_checkpoint

from mmcls.models.backbones import DistilledVisionTransformer
from .utils import timm_resize_pos_embed


def check_norm_state(modules, train_state):
    """Check if norm layer is in correct train state."""
    for mod in modules:
        if isinstance(mod, _BatchNorm):
            if mod.training != train_state:
                return False
    return True
class TestDeiT(TestCase):

    def setUp(self):
        self.cfg = dict(
            arch='deit-base', img_size=224, patch_size=16, drop_rate=0.1)

def test_deit_backbone():
    cfg_ori = dict(arch='deit-b', img_size=224, patch_size=16)
    def test_init_weights(self):
        # test weight init cfg
        cfg = deepcopy(self.cfg)
        cfg['init_cfg'] = [
            dict(
                type='Kaiming',
                layer='Conv2d',
                mode='fan_in',
                nonlinearity='linear')
        ]
        model = DistilledVisionTransformer(**cfg)
        ori_weight = model.patch_embed.projection.weight.clone().detach()
        # The pos_embed is all zero before initialize
        self.assertTrue(torch.allclose(model.dist_token, torch.tensor(0.)))

    # Test structure
    model = DistilledVisionTransformer(**cfg_ori)
    model.init_weights()
    model.train()
        model.init_weights()
        initialized_weight = model.patch_embed.projection.weight
        self.assertFalse(torch.allclose(ori_weight, initialized_weight))
        self.assertFalse(torch.allclose(model.dist_token, torch.tensor(0.)))

    assert check_norm_state(model.modules(), True)
    assert model.dist_token.shape == (1, 1, 768)
    assert model.pos_embed.shape == (1, model.patch_embed.num_patches + 2, 768)
        # test load checkpoint
        pretrain_pos_embed = model.pos_embed.clone().detach()
        tmpdir = tempfile.gettempdir()
        checkpoint = os.path.join(tmpdir, 'test.pth')
        save_checkpoint(model, checkpoint)
        cfg = deepcopy(self.cfg)
        model = DistilledVisionTransformer(**cfg)
        load_checkpoint(model, checkpoint, strict=True)
        self.assertTrue(torch.allclose(model.pos_embed, pretrain_pos_embed))

    # Test forward
    imgs = torch.rand(1, 3, 224, 224)
    outs = model(imgs)
    patch_token, cls_token, dist_token = outs[0]
    assert patch_token.shape == (1, 768, 14, 14)
    assert cls_token.shape == (1, 768)
    assert dist_token.shape == (1, 768)
        # test load checkpoint with different img_size
        cfg = deepcopy(self.cfg)
        cfg['img_size'] = 384
        model = DistilledVisionTransformer(**cfg)
        load_checkpoint(model, checkpoint, strict=True)
        resized_pos_embed = timm_resize_pos_embed(
            pretrain_pos_embed, model.pos_embed, num_tokens=2)
        self.assertTrue(torch.allclose(model.pos_embed, resized_pos_embed))

    # Test multiple out_indices
    model = DistilledVisionTransformer(
        **cfg_ori, out_indices=(0, 1, 2, 3), output_cls_token=False)
    outs = model(imgs)
    for out in outs:
        assert out.shape == (1, 768, 14, 14)
        os.remove(checkpoint)

    def test_forward(self):
        imgs = torch.randn(3, 3, 224, 224)

        # test with_cls_token=False
        cfg = deepcopy(self.cfg)
        cfg['with_cls_token'] = False
        cfg['output_cls_token'] = True
        with self.assertRaisesRegex(AssertionError, 'but got False'):
            DistilledVisionTransformer(**cfg)

        cfg = deepcopy(self.cfg)
        cfg['with_cls_token'] = False
        cfg['output_cls_token'] = False
        model = DistilledVisionTransformer(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        patch_token = outs[-1]
        self.assertEqual(patch_token.shape, (3, 768, 14, 14))

        # test with output_cls_token
        cfg = deepcopy(self.cfg)
        model = DistilledVisionTransformer(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        patch_token, cls_token, dist_token = outs[-1]
        self.assertEqual(patch_token.shape, (3, 768, 14, 14))
        self.assertEqual(cls_token.shape, (3, 768))
        self.assertEqual(dist_token.shape, (3, 768))

        # test without output_cls_token
        cfg = deepcopy(self.cfg)
        cfg['output_cls_token'] = False
        model = DistilledVisionTransformer(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        patch_token = outs[-1]
        self.assertEqual(patch_token.shape, (3, 768, 14, 14))

        # Test forward with multi out indices
        cfg = deepcopy(self.cfg)
        cfg['out_indices'] = [-3, -2, -1]
        model = DistilledVisionTransformer(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 3)
        for out in outs:
            patch_token, cls_token, dist_token = out
            self.assertEqual(patch_token.shape, (3, 768, 14, 14))
            self.assertEqual(cls_token.shape, (3, 768))
            self.assertEqual(dist_token.shape, (3, 768))

        # Test forward with dynamic input size
        imgs1 = torch.randn(3, 3, 224, 224)
        imgs2 = torch.randn(3, 3, 256, 256)
        imgs3 = torch.randn(3, 3, 256, 309)
        cfg = deepcopy(self.cfg)
        model = DistilledVisionTransformer(**cfg)
        for imgs in [imgs1, imgs2, imgs3]:
            outs = model(imgs)
            self.assertIsInstance(outs, tuple)
            self.assertEqual(len(outs), 1)
            patch_token, cls_token, dist_token = outs[-1]
            expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
                                 math.ceil(imgs.shape[3] / 16))
            self.assertEqual(patch_token.shape, (3, 768, *expect_feat_shape))
            self.assertEqual(cls_token.shape, (3, 768))
            self.assertEqual(dist_token.shape, (3, 768))
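The dynamic-shape assertions above reduce to simple patch-grid arithmetic; a standalone sketch (plain Python, values purely illustrative):

import math

patch_size = 16
for h, w in [(224, 224), (256, 256), (256, 309)]:
    grid = (math.ceil(h / patch_size), math.ceil(w / patch_size))
    print((h, w), '->', grid)
# (224, 224) -> (14, 14), (256, 256) -> (16, 16), (256, 309) -> (16, 20)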
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from copy import deepcopy
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.nn.modules import GroupNorm
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
@ -25,51 +25,95 @@ def check_norm_state(modules, train_state):
|
|||
return True
|
||||
|
||||
|
||||
def test_mlp_mixer_backbone():
|
||||
cfg_ori = dict(
|
||||
arch='b',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
drop_rate=0.1,
|
||||
init_cfg=[
|
||||
class TestMLPMixer(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.cfg = dict(
|
||||
arch='b',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
drop_rate=0.1,
|
||||
init_cfg=[
|
||||
dict(
|
||||
type='Kaiming',
|
||||
layer='Conv2d',
|
||||
mode='fan_in',
|
||||
nonlinearity='linear')
|
||||
])
|
||||
|
||||
def test_arch(self):
|
||||
# Test invalid default arch
|
||||
with self.assertRaisesRegex(AssertionError, 'not in default archs'):
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = 'unknown'
|
||||
MlpMixer(**cfg)
|
||||
|
||||
# Test invalid custom arch
|
||||
with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = {
|
||||
'embed_dims': 24,
|
||||
'num_layers': 16,
|
||||
'tokens_mlp_dims': 4096
|
||||
}
|
||||
MlpMixer(**cfg)
|
||||
|
||||
# Test custom arch
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = {
|
||||
'embed_dims': 128,
|
||||
'num_layers': 6,
|
||||
'tokens_mlp_dims': 256,
|
||||
'channels_mlp_dims': 1024
|
||||
}
|
||||
model = MlpMixer(**cfg)
|
||||
self.assertEqual(model.embed_dims, 128)
|
||||
self.assertEqual(model.num_layers, 6)
|
||||
for layer in model.layers:
|
||||
self.assertEqual(layer.token_mix.feedforward_channels, 256)
|
||||
self.assertEqual(layer.channel_mix.feedforward_channels, 1024)
|
||||
|
||||
def test_init_weights(self):
|
||||
# test weight init cfg
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['init_cfg'] = [
|
||||
dict(
|
||||
type='Kaiming',
|
||||
layer='Conv2d',
|
||||
mode='fan_in',
|
||||
nonlinearity='linear')
|
||||
])
|
||||
]
|
||||
model = MlpMixer(**cfg)
|
||||
ori_weight = model.patch_embed.projection.weight.clone().detach()
|
||||
model.init_weights()
|
||||
initialized_weight = model.patch_embed.projection.weight
|
||||
self.assertFalse(torch.allclose(ori_weight, initialized_weight))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# test invalid arch
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['arch'] = 'unknown'
|
||||
MlpMixer(**cfg)
|
||||
def test_forward(self):
|
||||
imgs = torch.randn(3, 3, 224, 224)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# test arch without essential keys
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['arch'] = {
|
||||
'num_layers': 24,
|
||||
'tokens_mlp_dims': 384,
|
||||
'channels_mlp_dims': 3072,
|
||||
}
|
||||
MlpMixer(**cfg)
|
||||
# test forward with single out indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = MlpMixer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
feat = outs[-1]
|
||||
self.assertEqual(feat.shape, (3, 768, 196))
|
||||
|
||||
# Test MlpMixer base model with input size of 224
|
||||
# and patch size of 16
|
||||
model = MlpMixer(**cfg_ori)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
# test forward with multi out indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['out_indices'] = [-3, -2, -1]
|
||||
model = MlpMixer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 3)
|
||||
for feat in outs:
|
||||
self.assertEqual(feat.shape, (3, 768, 196))
|
||||
|
||||
assert check_norm_state(model.modules(), True)
|
||||
|
||||
imgs = torch.randn(3, 3, 224, 224)
|
||||
feat = model(imgs)[-1]
|
||||
assert feat.shape == (3, 768, 196)
|
||||
|
||||
# Test MlpMixer with multi out indices
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['out_indices'] = [-3, -2, -1]
|
||||
model = MlpMixer(**cfg)
|
||||
for out in model(imgs):
|
||||
assert out.shape == (3, 768, 196)
|
||||
# test with invalid input shape
|
||||
imgs2 = torch.randn(3, 3, 256, 256)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = MlpMixer(**cfg)
|
||||
with self.assertRaisesRegex(AssertionError, 'dynamic input shape.'):
|
||||
model(imgs2)
|
||||
|
|
|
@ -1,16 +1,18 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
import os
|
||||
import tempfile
|
||||
from math import ceil
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv.runner import load_checkpoint, save_checkpoint
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import SwinTransformer
|
||||
from mmcls.models.backbones.swin_transformer import SwinBlock
|
||||
from .utils import timm_resize_pos_embed
|
||||
|
||||
|
||||
def check_norm_state(modules, train_state):
|
||||
|
@ -22,215 +24,232 @@ def check_norm_state(modules, train_state):
|
|||
return True
|
||||
|
||||
|
||||
def test_assertion():
|
||||
"""Test Swin Transformer backbone."""
|
||||
with pytest.raises(AssertionError):
|
||||
# Swin Transformer arch string should be in
|
||||
SwinTransformer(arch='unknown')
|
||||
class TestSwinTransformer(TestCase):
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Swin Transformer arch dict should include 'embed_dims',
|
||||
# 'depths' and 'num_head' keys.
|
||||
SwinTransformer(arch=dict(embed_dims=96, depths=[2, 2, 18, 2]))
|
||||
def setUp(self):
|
||||
self.cfg = dict(
|
||||
arch='b', img_size=224, patch_size=4, drop_path_rate=0.1)
|
||||
|
||||
def test_arch(self):
|
||||
# Test invalid default arch
|
||||
with self.assertRaisesRegex(AssertionError, 'not in default archs'):
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = 'unknown'
|
||||
SwinTransformer(**cfg)
|
||||
|
||||
def test_forward():
|
||||
# Test tiny arch forward
|
||||
model = SwinTransformer(arch='Tiny')
|
||||
model.init_weights()
|
||||
model.train()
|
||||
# Test invalid custom arch
|
||||
with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = {
|
||||
'embed_dims': 96,
|
||||
'num_heads': [3, 6, 12, 16],
|
||||
}
|
||||
SwinTransformer(**cfg)
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
output = model(imgs)
|
||||
assert len(output) == 1
|
||||
assert output[0].shape == (1, 768, 7, 7)
|
||||
# Test custom arch
|
||||
cfg = deepcopy(self.cfg)
|
||||
depths = [2, 2, 4, 2]
|
||||
num_heads = [6, 12, 6, 12]
|
||||
cfg['arch'] = {
|
||||
'embed_dims': 256,
|
||||
'depths': depths,
|
||||
'num_heads': num_heads
|
||||
}
|
||||
model = SwinTransformer(**cfg)
|
||||
for i, stage in enumerate(model.stages):
|
||||
self.assertEqual(stage.embed_dims, 256 * (2**i))
|
||||
self.assertEqual(len(stage.blocks), depths[i])
|
||||
self.assertEqual(stage.blocks[0].attn.w_msa.num_heads,
|
||||
num_heads[i])
|
||||
|
||||
# Test small arch forward
|
||||
model = SwinTransformer(arch='small')
|
||||
model.init_weights()
|
||||
model.train()
|
||||
def test_init_weights(self):
|
||||
# test weight init cfg
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['use_abs_pos_embed'] = True
|
||||
cfg['init_cfg'] = [
|
||||
dict(
|
||||
type='Kaiming',
|
||||
layer='Conv2d',
|
||||
mode='fan_in',
|
||||
nonlinearity='linear')
|
||||
]
|
||||
model = SwinTransformer(**cfg)
|
||||
ori_weight = model.patch_embed.projection.weight.clone().detach()
|
||||
# The pos_embed is all zero before initialize
|
||||
self.assertTrue(
|
||||
torch.allclose(model.absolute_pos_embed, torch.tensor(0.)))
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
output = model(imgs)
|
||||
assert len(output) == 1
|
||||
assert output[0].shape == (1, 768, 7, 7)
|
||||
model.init_weights()
|
||||
initialized_weight = model.patch_embed.projection.weight
|
||||
self.assertFalse(torch.allclose(ori_weight, initialized_weight))
|
||||
self.assertFalse(
|
||||
torch.allclose(model.absolute_pos_embed, torch.tensor(0.)))
|
||||
|
||||
# Test base arch forward
|
||||
model = SwinTransformer(arch='B')
|
||||
model.init_weights()
|
||||
model.train()
|
||||
pretrain_pos_embed = model.absolute_pos_embed.clone().detach()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
output = model(imgs)
|
||||
assert len(output) == 1
|
||||
assert output[0].shape == (1, 1024, 7, 7)
|
||||
tmpdir = tempfile.gettempdir()
|
||||
# Save v3 checkpoints
|
||||
checkpoint_v2 = os.path.join(tmpdir, 'v3.pth')
|
||||
save_checkpoint(model, checkpoint_v2)
|
||||
# Save v1 checkpoints
|
||||
setattr(model, 'norm', model.norm3)
|
||||
setattr(model.stages[0].blocks[1].attn, 'attn_mask',
|
||||
torch.zeros(64, 49, 49))
|
||||
model._version = 1
|
||||
del model.norm3
|
||||
checkpoint_v1 = os.path.join(tmpdir, 'v1.pth')
|
||||
save_checkpoint(model, checkpoint_v1)
|
||||
|
||||
# Test large arch forward
|
||||
model = SwinTransformer(arch='l')
|
||||
model.init_weights()
|
||||
model.train()
|
||||
# test load v1 checkpoint
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['use_abs_pos_embed'] = True
|
||||
model = SwinTransformer(**cfg)
|
||||
load_checkpoint(model, checkpoint_v1, strict=True)
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
output = model(imgs)
|
||||
assert len(output) == 1
|
||||
assert output[0].shape == (1, 1536, 7, 7)
|
||||
# test load v3 checkpoint
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['use_abs_pos_embed'] = True
|
||||
model = SwinTransformer(**cfg)
|
||||
load_checkpoint(model, checkpoint_v2, strict=True)
|
||||
|
||||
# Test base arch with window_size=12, image_size=384
|
||||
model = SwinTransformer(
|
||||
arch='base',
|
||||
img_size=384,
|
||||
stage_cfgs=dict(block_cfgs=dict(window_size=12)))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
# test load v3 checkpoint with different img_size
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['img_size'] = 384
|
||||
cfg['use_abs_pos_embed'] = True
|
||||
model = SwinTransformer(**cfg)
|
||||
load_checkpoint(model, checkpoint_v2, strict=True)
|
||||
resized_pos_embed = timm_resize_pos_embed(
|
||||
pretrain_pos_embed, model.absolute_pos_embed, num_tokens=0)
|
||||
self.assertTrue(
|
||||
torch.allclose(model.absolute_pos_embed, resized_pos_embed))
|
||||
|
||||
imgs = torch.randn(1, 3, 384, 384)
|
||||
output = model(imgs)
|
||||
assert len(output) == 1
|
||||
assert output[0].shape == (1, 1024, 12, 12)
|
||||
os.remove(checkpoint_v1)
|
||||
os.remove(checkpoint_v2)
|
||||
|
||||
# Test multiple output indices
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
model = SwinTransformer(arch='base', out_indices=(0, 1, 2, 3))
|
||||
outs = model(imgs)
|
||||
assert outs[0].shape == (1, 256, 28, 28)
|
||||
assert outs[1].shape == (1, 512, 14, 14)
|
||||
assert outs[2].shape == (1, 1024, 7, 7)
|
||||
assert outs[3].shape == (1, 1024, 7, 7)
|
||||
def test_forward(self):
|
||||
imgs = torch.randn(3, 3, 224, 224)
|
||||
|
||||
# Test base arch with with checkpoint forward
|
||||
model = SwinTransformer(arch='B', with_cp=True)
|
||||
for m in model.modules():
|
||||
if isinstance(m, SwinBlock):
|
||||
assert m.with_cp
|
||||
model.init_weights()
|
||||
model.train()
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = SwinTransformer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
feat = outs[-1]
|
||||
self.assertEqual(feat.shape, (3, 1024, 7, 7))
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
output = model(imgs)
|
||||
assert len(output) == 1
|
||||
assert output[0].shape == (1, 1024, 7, 7)
|
||||
# test with window_size=12
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['window_size'] = 12
|
||||
model = SwinTransformer(**cfg)
|
||||
outs = model(torch.randn(3, 3, 384, 384))
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
feat = outs[-1]
|
||||
self.assertEqual(feat.shape, (3, 1024, 12, 12))
|
||||
with self.assertRaisesRegex(AssertionError, r'the window size \(12\)'):
|
||||
model(torch.randn(3, 3, 224, 224))
|
||||
|
||||
# test with pad_small_map=True
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['window_size'] = 12
|
||||
cfg['pad_small_map'] = True
|
||||
model = SwinTransformer(**cfg)
|
||||
outs = model(torch.randn(3, 3, 224, 224))
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
feat = outs[-1]
|
||||
self.assertEqual(feat.shape, (3, 1024, 7, 7))
|
||||
|
||||
def test_structure():
|
||||
# Test small with use_abs_pos_embed = True
|
||||
model = SwinTransformer(arch='small', use_abs_pos_embed=True)
|
||||
assert model.absolute_pos_embed.shape == (1, 3136, 96)
|
||||
# test multiple output indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['out_indices'] = (0, 1, 2, 3)
|
||||
model = SwinTransformer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 4)
|
||||
for stride, out in zip([2, 4, 8, 8], outs):
|
||||
self.assertEqual(out.shape,
|
||||
(3, 128 * stride, 56 // stride, 56 // stride))
|
||||
|
||||
# Test small with use_abs_pos_embed = False
|
||||
model = SwinTransformer(arch='small', use_abs_pos_embed=False)
|
||||
assert not hasattr(model, 'absolute_pos_embed')
|
||||
# test with checkpoint forward
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['with_cp'] = True
|
||||
model = SwinTransformer(**cfg)
|
||||
for m in model.modules():
|
||||
if isinstance(m, SwinBlock):
|
||||
self.assertTrue(m.with_cp)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
# Test small with auto_pad = True
|
||||
model = SwinTransformer(
|
||||
arch='small',
|
||||
auto_pad=True,
|
||||
stage_cfgs=dict(
|
||||
block_cfgs={'window_size': 7},
|
||||
downsample_cfg={
|
||||
'kernel_size': (3, 2),
|
||||
}))
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
feat = outs[-1]
|
||||
self.assertEqual(feat.shape, (3, 1024, 7, 7))
|
||||
|
||||
# stage 1
|
||||
input_h = int(224 / 4 / 3)
|
||||
expect_h = ceil(input_h / 7) * 7
|
||||
input_w = int(224 / 4 / 2)
|
||||
expect_w = ceil(input_w / 7) * 7
|
||||
assert model.stages[1].blocks[0].attn.pad_b == expect_h - input_h
|
||||
assert model.stages[1].blocks[0].attn.pad_r == expect_w - input_w
|
||||
# test with dynamic input shape
|
||||
imgs1 = torch.randn(3, 3, 224, 224)
|
||||
imgs2 = torch.randn(3, 3, 256, 256)
|
||||
imgs3 = torch.randn(3, 3, 256, 309)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = SwinTransformer(**cfg)
|
||||
for imgs in [imgs1, imgs2, imgs3]:
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
feat = outs[-1]
|
||||
expect_feat_shape = (math.ceil(imgs.shape[2] / 32),
|
||||
math.ceil(imgs.shape[3] / 32))
|
||||
self.assertEqual(feat.shape, (3, 1024, *expect_feat_shape))
|
||||
|
||||
# stage 2
|
||||
input_h = int(224 / 4 / 3 / 3)
|
||||
# input_h is smaller than window_size, shrink the window_size to input_h.
|
||||
expect_h = input_h
|
||||
input_w = int(224 / 4 / 2 / 2)
|
||||
expect_w = ceil(input_w / input_h) * input_h
|
||||
assert model.stages[2].blocks[0].attn.pad_b == expect_h - input_h
|
||||
assert model.stages[2].blocks[0].attn.pad_r == expect_w - input_w
|
||||
def test_structure(self):
|
||||
# test drop_path_rate decay
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['drop_path_rate'] = 0.2
|
||||
model = SwinTransformer(**cfg)
|
||||
depths = model.arch_settings['depths']
|
||||
blocks = chain(*[stage.blocks for stage in model.stages])
|
||||
for i, block in enumerate(blocks):
|
||||
expect_prob = 0.2 / (sum(depths) - 1) * i
|
||||
self.assertAlmostEqual(block.ffn.dropout_layer.drop_prob,
|
||||
expect_prob)
|
||||
self.assertAlmostEqual(block.attn.drop.drop_prob, expect_prob)
|
||||
|
||||
# stage 3
|
||||
input_h = int(224 / 4 / 3 / 3 / 3)
|
||||
expect_h = input_h
|
||||
input_w = int(224 / 4 / 2 / 2 / 2)
|
||||
expect_w = ceil(input_w / input_h) * input_h
|
||||
assert model.stages[3].blocks[0].attn.pad_b == expect_h - input_h
|
||||
assert model.stages[3].blocks[0].attn.pad_r == expect_w - input_w
|
||||
# test Swin-Transformer with norm_eval=True
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['norm_eval'] = True
|
||||
cfg['norm_cfg'] = dict(type='BN')
|
||||
cfg['stage_cfgs'] = dict(block_cfgs=dict(norm_cfg=dict(type='BN')))
|
||||
model = SwinTransformer(**cfg)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
self.assertTrue(check_norm_state(model.modules(), False))
|
||||
|
||||
# Test small with auto_pad = False
|
||||
with pytest.raises(AssertionError):
|
||||
model = SwinTransformer(
|
||||
arch='small',
|
||||
auto_pad=False,
|
||||
stage_cfgs=dict(
|
||||
block_cfgs={'window_size': 7},
|
||||
downsample_cfg={
|
||||
'kernel_size': (3, 2),
|
||||
}))
|
||||
# test Swin-Transformer with first stage frozen.
|
||||
cfg = deepcopy(self.cfg)
|
||||
frozen_stages = 0
|
||||
cfg['frozen_stages'] = frozen_stages
|
||||
cfg['out_indices'] = (0, 1, 2, 3)
|
||||
model = SwinTransformer(**cfg)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
# Test drop_path_rate decay
|
||||
model = SwinTransformer(
|
||||
arch='small',
|
||||
drop_path_rate=0.2,
|
||||
)
|
||||
depths = model.arch_settings['depths']
|
||||
pos = 0
|
||||
for i, depth in enumerate(depths):
|
||||
for j in range(depth):
|
||||
block = model.stages[i].blocks[j]
|
||||
expect_prob = 0.2 / (sum(depths) - 1) * pos
|
||||
assert np.isclose(block.ffn.dropout_layer.drop_prob, expect_prob)
|
||||
assert np.isclose(block.attn.drop.drop_prob, expect_prob)
|
||||
pos += 1
|
||||
# the patch_embed and first stage should not require grad.
|
||||
self.assertFalse(model.patch_embed.training)
|
||||
for param in model.patch_embed.parameters():
|
||||
self.assertFalse(param.requires_grad)
|
||||
for i in range(frozen_stages + 1):
|
||||
stage = model.stages[i]
|
||||
for param in stage.parameters():
|
||||
self.assertFalse(param.requires_grad)
|
||||
for param in model.norm0.parameters():
|
||||
self.assertFalse(param.requires_grad)
|
||||
|
||||
# Test Swin-Transformer with norm_eval=True
|
||||
model = SwinTransformer(
|
||||
arch='small',
|
||||
norm_eval=True,
|
||||
norm_cfg=dict(type='BN'),
|
||||
stage_cfgs=dict(block_cfgs=dict(norm_cfg=dict(type='BN'))),
|
||||
)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
assert check_norm_state(model.modules(), False)
|
||||
|
||||
# Test Swin-Transformer with first stage frozen.
|
||||
frozen_stages = 0
|
||||
model = SwinTransformer(
|
||||
arch='small', frozen_stages=frozen_stages, out_indices=(0, 1, 2, 3))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
assert model.patch_embed.training is False
|
||||
for param in model.patch_embed.parameters():
|
||||
assert param.requires_grad is False
|
||||
for i in range(frozen_stages + 1):
|
||||
stage = model.stages[i]
|
||||
for param in stage.parameters():
|
||||
assert param.requires_grad is False
|
||||
for param in model.norm0.parameters():
|
||||
assert param.requires_grad is False
|
||||
|
||||
# the second stage should require grad.
|
||||
stage = model.stages[1]
|
||||
for param in stage.parameters():
|
||||
assert param.requires_grad is True
|
||||
for param in model.norm1.parameters():
|
||||
assert param.requires_grad is True
|
||||
|
||||
|
||||
def test_load_checkpoint():
|
||||
model = SwinTransformer(arch='tiny')
|
||||
ckpt_path = os.path.join(tempfile.gettempdir(), 'ckpt.pth')
|
||||
|
||||
assert model._version == 2
|
||||
|
||||
# test load v2 checkpoint
|
||||
save_checkpoint(model, ckpt_path)
|
||||
load_checkpoint(model, ckpt_path, strict=True)
|
||||
|
||||
# test load v1 checkpoint
|
||||
setattr(model, 'norm', model.norm3)
|
||||
model._version = 1
|
||||
del model.norm3
|
||||
save_checkpoint(model, ckpt_path)
|
||||
model = SwinTransformer(arch='tiny')
|
||||
load_checkpoint(model, ckpt_path, strict=True)
|
||||
# the second stage should require grad.
|
||||
for i in range(frozen_stages + 1, 4):
|
||||
stage = model.stages[i]
|
||||
for param in stage.parameters():
|
||||
self.assertTrue(param.requires_grad)
|
||||
norm = getattr(model, f'norm{i}')
|
||||
for param in norm.parameters():
|
||||
self.assertTrue(param.requires_grad)
|
||||
|
|
|
@ -1,84 +1,157 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
import os
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.nn.modules import GroupNorm
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from mmcv.runner import load_checkpoint, save_checkpoint
|
||||
|
||||
from mmcls.models.backbones import T2T_ViT
|
||||
from .utils import timm_resize_pos_embed
|
||||
|
||||
|
||||
def is_norm(modules):
|
||||
"""Check if is one of the norms."""
|
||||
if isinstance(modules, (GroupNorm, _BatchNorm)):
|
||||
return True
|
||||
return False
|
||||
class TestT2TViT(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.cfg = dict(
|
||||
img_size=224,
|
||||
in_channels=3,
|
||||
embed_dims=384,
|
||||
t2t_cfg=dict(
|
||||
token_dims=64,
|
||||
use_performer=False,
|
||||
),
|
||||
num_layers=14,
|
||||
drop_path_rate=0.1)
|
||||
|
||||
def check_norm_state(modules, train_state):
|
||||
"""Check if norm layer is in correct train state."""
|
||||
for mod in modules:
|
||||
if isinstance(mod, _BatchNorm):
|
||||
if mod.training != train_state:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_vit_backbone():
|
||||
|
||||
cfg_ori = dict(
|
||||
img_size=224,
|
||||
in_channels=3,
|
||||
embed_dims=384,
|
||||
t2t_cfg=dict(
|
||||
token_dims=64,
|
||||
use_performer=False,
|
||||
),
|
||||
num_layers=14,
|
||||
layer_cfgs=dict(
|
||||
num_heads=6,
|
||||
feedforward_channels=3 * 384, # mlp_ratio = 3
|
||||
),
|
||||
drop_path_rate=0.1,
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=.02),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||
])
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
# test if use performer
|
||||
cfg = deepcopy(cfg_ori)
|
||||
def test_structure(self):
|
||||
# The performer hasn't been implemented
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['t2t_cfg']['use_performer'] = True
|
||||
T2T_ViT(**cfg)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
T2T_ViT(**cfg)
|
||||
|
||||
# Test T2T-ViT model with input size of 224
|
||||
model = T2T_ViT(**cfg_ori)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
# Test out_indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['out_indices'] = {1: 1}
|
||||
with self.assertRaisesRegex(AssertionError, "get <class 'dict'>"):
|
||||
T2T_ViT(**cfg)
|
||||
cfg['out_indices'] = [0, 15]
|
||||
with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 15'):
|
||||
T2T_ViT(**cfg)
|
||||
|
||||
assert check_norm_state(model.modules(), True)
|
||||
# Test model structure
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = T2T_ViT(**cfg)
|
||||
self.assertEqual(len(model.encoder), 14)
|
||||
dpr_inc = 0.1 / (14 - 1)
|
||||
dpr = 0
|
||||
for layer in model.encoder:
|
||||
self.assertEqual(layer.attn.embed_dims, 384)
|
||||
# The default mlp_ratio is 3
|
||||
self.assertEqual(layer.ffn.feedforward_channels, 384 * 3)
|
||||
self.assertAlmostEqual(layer.attn.out_drop.drop_prob, dpr)
|
||||
self.assertAlmostEqual(layer.ffn.dropout_layer.drop_prob, dpr)
|
||||
dpr += dpr_inc
|
||||
|
||||
imgs = torch.randn(3, 3, 224, 224)
|
||||
patch_token, cls_token = model(imgs)[-1]
|
||||
assert cls_token.shape == (3, 384)
|
||||
assert patch_token.shape == (3, 384, 14, 14)
|
||||
def test_init_weights(self):
|
||||
# test weight init cfg
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['init_cfg'] = [dict(type='TruncNormal', layer='Linear', std=.02)]
|
||||
model = T2T_ViT(**cfg)
|
||||
ori_weight = model.tokens_to_token.project.weight.clone().detach()
|
||||
|
||||
# Test custom arch T2T-ViT without output cls token
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['embed_dims'] = 256
|
||||
cfg['num_layers'] = 16
|
||||
cfg['layer_cfgs'] = dict(num_heads=8, feedforward_channels=1024)
|
||||
cfg['output_cls_token'] = False
|
||||
model.init_weights()
|
||||
initialized_weight = model.tokens_to_token.project.weight
|
||||
self.assertFalse(torch.allclose(ori_weight, initialized_weight))
|
||||
|
||||
model = T2T_ViT(**cfg)
|
||||
patch_token = model(imgs)[-1]
|
||||
assert patch_token.shape == (3, 256, 14, 14)
|
||||
# test load checkpoint
|
||||
pretrain_pos_embed = model.pos_embed.clone().detach()
|
||||
tmpdir = tempfile.gettempdir()
|
||||
checkpoint = os.path.join(tmpdir, 'test.pth')
|
||||
save_checkpoint(model, checkpoint)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = T2T_ViT(**cfg)
|
||||
load_checkpoint(model, checkpoint, strict=True)
|
||||
self.assertTrue(torch.allclose(model.pos_embed, pretrain_pos_embed))
|
||||
|
||||
# Test T2T_ViT with multi out indices
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['out_indices'] = [-3, -2, -1]
|
||||
model = T2T_ViT(**cfg)
|
||||
for out in model(imgs):
|
||||
assert out[0].shape == (3, 384, 14, 14)
|
||||
assert out[1].shape == (3, 384)
|
||||
# test load checkpoint with different img_size
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['img_size'] = 384
|
||||
model = T2T_ViT(**cfg)
|
||||
load_checkpoint(model, checkpoint, strict=True)
|
||||
resized_pos_embed = timm_resize_pos_embed(pretrain_pos_embed,
|
||||
model.pos_embed)
|
||||
self.assertTrue(torch.allclose(model.pos_embed, resized_pos_embed))
|
||||
|
||||
os.remove(checkpoint)
|
||||
|
||||
def test_forward(self):
|
||||
imgs = torch.randn(3, 3, 224, 224)
|
||||
|
||||
# test with_cls_token=False
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['with_cls_token'] = False
|
||||
cfg['output_cls_token'] = True
|
||||
with self.assertRaisesRegex(AssertionError, 'but got False'):
|
||||
T2T_ViT(**cfg)
|
||||
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['with_cls_token'] = False
|
||||
cfg['output_cls_token'] = False
|
||||
model = T2T_ViT(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
|
||||
|
||||
# test with output_cls_token
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = T2T_ViT(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token, cls_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
|
||||
self.assertEqual(cls_token.shape, (3, 384))
|
||||
|
||||
# test without output_cls_token
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['output_cls_token'] = False
|
||||
model = T2T_ViT(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
|
||||
|
||||
# Test forward with multi out indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['out_indices'] = [-3, -2, -1]
|
||||
model = T2T_ViT(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 3)
|
||||
for out in outs:
|
||||
patch_token, cls_token = out
|
||||
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
|
||||
self.assertEqual(cls_token.shape, (3, 384))
|
||||
|
||||
# Test forward with dynamic input size
|
||||
imgs1 = torch.randn(3, 3, 224, 224)
|
||||
imgs2 = torch.randn(3, 3, 256, 256)
|
||||
imgs3 = torch.randn(3, 3, 256, 309)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = T2T_ViT(**cfg)
|
||||
for imgs in [imgs1, imgs2, imgs3]:
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token, cls_token = outs[-1]
|
||||
expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
|
||||
math.ceil(imgs.shape[3] / 16))
|
||||
self.assertEqual(patch_token.shape, (3, 384, *expect_feat_shape))
|
||||
self.assertEqual(cls_token.shape, (3, 384))
|
||||
|
|
|
@ -3,160 +3,181 @@ import math
|
|||
import os
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.modules import GroupNorm
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from mmcv.runner import load_checkpoint, save_checkpoint
|
||||
|
||||
from mmcls.models.backbones import VisionTransformer
|
||||
from .utils import timm_resize_pos_embed
|
||||
|
||||
|
||||
def is_norm(modules):
|
||||
"""Check if is one of the norms."""
|
||||
if isinstance(modules, (GroupNorm, _BatchNorm)):
|
||||
return True
|
||||
return False
|
||||
class TestVisionTransformer(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.cfg = dict(
|
||||
arch='b', img_size=224, patch_size=16, drop_path_rate=0.1)
|
||||
|
||||
def check_norm_state(modules, train_state):
|
||||
"""Check if norm layer is in correct train state."""
|
||||
for mod in modules:
|
||||
if isinstance(mod, _BatchNorm):
|
||||
if mod.training != train_state:
|
||||
return False
|
||||
return True
|
||||
def test_structure(self):
|
||||
# Test invalid default arch
|
||||
with self.assertRaisesRegex(AssertionError, 'not in default archs'):
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = 'unknown'
|
||||
VisionTransformer(**cfg)
|
||||
|
||||
# Test invalid custom arch
|
||||
with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = {
|
||||
'num_layers': 24,
|
||||
'num_heads': 16,
|
||||
'feedforward_channels': 4096
|
||||
}
|
||||
VisionTransformer(**cfg)
|
||||
|
||||
def test_vit_backbone():
|
||||
# Test custom arch
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = {
|
||||
'embed_dims': 128,
|
||||
'num_layers': 24,
|
||||
'num_heads': 16,
|
||||
'feedforward_channels': 1024
|
||||
}
|
||||
model = VisionTransformer(**cfg)
|
||||
self.assertEqual(model.embed_dims, 128)
|
||||
self.assertEqual(model.num_layers, 24)
|
||||
for layer in model.layers:
|
||||
self.assertEqual(layer.attn.num_heads, 16)
|
||||
self.assertEqual(layer.ffn.feedforward_channels, 1024)
|
||||
|
||||
cfg_ori = dict(
|
||||
arch='b',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
drop_rate=0.1,
|
||||
init_cfg=[
|
||||
# Test out_indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['out_indices'] = {1: 1}
|
||||
with self.assertRaisesRegex(AssertionError, "get <class 'dict'>"):
|
||||
VisionTransformer(**cfg)
|
||||
cfg['out_indices'] = [0, 13]
|
||||
with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'):
|
||||
VisionTransformer(**cfg)
|
||||
|
||||
# Test model structure
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = VisionTransformer(**cfg)
|
||||
self.assertEqual(len(model.layers), 12)
|
||||
dpr_inc = 0.1 / (12 - 1)
|
||||
dpr = 0
|
||||
for layer in model.layers:
|
||||
self.assertEqual(layer.attn.embed_dims, 768)
|
||||
self.assertEqual(layer.attn.num_heads, 12)
|
||||
self.assertEqual(layer.ffn.feedforward_channels, 3072)
|
||||
self.assertAlmostEqual(layer.attn.out_drop.drop_prob, dpr)
|
||||
self.assertAlmostEqual(layer.ffn.dropout_layer.drop_prob, dpr)
|
||||
dpr += dpr_inc
|
||||
|
||||
def test_init_weights(self):
|
||||
# test weight init cfg
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['init_cfg'] = [
|
||||
dict(
|
||||
type='Kaiming',
|
||||
layer='Conv2d',
|
||||
mode='fan_in',
|
||||
nonlinearity='linear')
|
||||
])
|
||||
]
|
||||
model = VisionTransformer(**cfg)
|
||||
ori_weight = model.patch_embed.projection.weight.clone().detach()
|
||||
# The pos_embed is all zero before initialize
|
||||
self.assertTrue(torch.allclose(model.pos_embed, torch.tensor(0.)))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# test invalid arch
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['arch'] = 'unknown'
|
||||
VisionTransformer(**cfg)
|
||||
model.init_weights()
|
||||
initialized_weight = model.patch_embed.projection.weight
|
||||
self.assertFalse(torch.allclose(ori_weight, initialized_weight))
|
||||
self.assertFalse(torch.allclose(model.pos_embed, torch.tensor(0.)))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# test arch without essential keys
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['arch'] = {
|
||||
'num_layers': 24,
|
||||
'num_heads': 16,
|
||||
'feedforward_channels': 4096
|
||||
}
|
||||
VisionTransformer(**cfg)
|
||||
# test load checkpoint
|
||||
pretrain_pos_embed = model.pos_embed.clone().detach()
|
||||
tmpdir = tempfile.gettempdir()
|
||||
checkpoint = os.path.join(tmpdir, 'test.pth')
|
||||
save_checkpoint(model, checkpoint)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = VisionTransformer(**cfg)
|
||||
load_checkpoint(model, checkpoint, strict=True)
|
||||
self.assertTrue(torch.allclose(model.pos_embed, pretrain_pos_embed))
|
||||
|
||||
# Test ViT base model with input size of 224
|
||||
# and patch size of 16
|
||||
model = VisionTransformer(**cfg_ori)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
# test load checkpoint with different img_size
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['img_size'] = 384
|
||||
model = VisionTransformer(**cfg)
|
||||
load_checkpoint(model, checkpoint, strict=True)
|
||||
resized_pos_embed = timm_resize_pos_embed(pretrain_pos_embed,
|
||||
model.pos_embed)
|
||||
self.assertTrue(torch.allclose(model.pos_embed, resized_pos_embed))
|
||||
|
||||
assert check_norm_state(model.modules(), True)
|
||||
os.remove(checkpoint)
|
||||
|
||||
imgs = torch.randn(3, 3, 224, 224)
|
||||
patch_token, cls_token = model(imgs)[-1]
|
||||
assert cls_token.shape == (3, 768)
|
||||
assert patch_token.shape == (3, 768, 14, 14)
|
||||
def test_forward(self):
|
||||
imgs = torch.randn(3, 3, 224, 224)
|
||||
|
||||
# Test custom arch ViT without output cls token
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['arch'] = {
|
||||
'embed_dims': 128,
|
||||
'num_layers': 24,
|
||||
'num_heads': 16,
|
||||
'feedforward_channels': 1024
|
||||
}
|
||||
cfg['output_cls_token'] = False
|
||||
model = VisionTransformer(**cfg)
|
||||
patch_token = model(imgs)[-1]
|
||||
assert patch_token.shape == (3, 128, 14, 14)
|
||||
# test with_cls_token=False
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['with_cls_token'] = False
|
||||
cfg['output_cls_token'] = True
|
||||
with self.assertRaisesRegex(AssertionError, 'but got False'):
|
||||
VisionTransformer(**cfg)
|
||||
|
||||
# Test ViT with multi out indices
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['out_indices'] = [-3, -2, -1]
|
||||
model = VisionTransformer(**cfg)
|
||||
for out in model(imgs):
|
||||
assert out[0].shape == (3, 768, 14, 14)
|
||||
assert out[1].shape == (3, 768)
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['with_cls_token'] = False
|
||||
cfg['output_cls_token'] = False
|
||||
model = VisionTransformer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
|
||||
|
||||
# test with output_cls_token
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = VisionTransformer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token, cls_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
|
||||
self.assertEqual(cls_token.shape, (3, 768))
|
||||
|
||||
def timm_resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
|
||||
# Timm version pos embed resize function.
|
||||
# Refers to https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # noqa:E501
|
||||
ntok_new = posemb_new.shape[1]
|
||||
if num_tokens:
|
||||
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0,
|
||||
num_tokens:]
|
||||
ntok_new -= num_tokens
|
||||
else:
|
||||
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
|
||||
gs_old = int(math.sqrt(len(posemb_grid)))
|
||||
if not len(gs_new): # backwards compatibility
|
||||
gs_new = [int(math.sqrt(ntok_new))] * 2
|
||||
assert len(gs_new) >= 2
|
||||
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old,
|
||||
-1).permute(0, 3, 1, 2)
|
||||
posemb_grid = F.interpolate(
|
||||
posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
|
||||
posemb_grid = posemb_grid.permute(0, 2, 3,
|
||||
1).reshape(1, gs_new[0] * gs_new[1], -1)
|
||||
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
||||
return posemb
|
||||
# test without output_cls_token
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['output_cls_token'] = False
|
||||
model = VisionTransformer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
|
||||
|
||||
# Test forward with multi out indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['out_indices'] = [-3, -2, -1]
|
||||
model = VisionTransformer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 3)
|
||||
for out in outs:
|
||||
patch_token, cls_token = out
|
||||
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
|
||||
self.assertEqual(cls_token.shape, (3, 768))
|
||||
|
||||
def test_vit_weight_init():
|
||||
# test weight init cfg
|
||||
pretrain_cfg = dict(
|
||||
arch='b',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
init_cfg=[dict(type='Constant', val=1., layer='Conv2d')])
|
||||
pretrain_model = VisionTransformer(**pretrain_cfg)
|
||||
pretrain_model.init_weights()
|
||||
assert torch.allclose(pretrain_model.patch_embed.projection.weight,
|
||||
torch.tensor(1.))
|
||||
assert pretrain_model.pos_embed.abs().sum() > 0
|
||||
|
||||
pos_embed_weight = pretrain_model.pos_embed.detach()
|
||||
tmpdir = tempfile.gettempdir()
|
||||
checkpoint = os.path.join(tmpdir, 'test.pth')
|
||||
torch.save(pretrain_model.state_dict(), checkpoint)
|
||||
|
||||
# test load checkpoint
|
||||
finetune_cfg = dict(
|
||||
arch='b',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=checkpoint))
|
||||
finetune_model = VisionTransformer(**finetune_cfg)
|
||||
finetune_model.init_weights()
|
||||
assert torch.allclose(finetune_model.pos_embed, pos_embed_weight)
|
||||
|
||||
# test load checkpoint with different img_size
|
||||
finetune_cfg = dict(
|
||||
arch='b',
|
||||
img_size=384,
|
||||
patch_size=16,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=checkpoint))
|
||||
finetune_model = VisionTransformer(**finetune_cfg)
|
||||
finetune_model.init_weights()
|
||||
resized_pos_embed = timm_resize_pos_embed(pos_embed_weight,
|
||||
finetune_model.pos_embed)
|
||||
assert torch.allclose(finetune_model.pos_embed, resized_pos_embed)
|
||||
|
||||
os.remove(checkpoint)
|
||||
# Test forward with dynamic input size
|
||||
imgs1 = torch.randn(3, 3, 224, 224)
|
||||
imgs2 = torch.randn(3, 3, 256, 256)
|
||||
imgs3 = torch.randn(3, 3, 256, 309)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = VisionTransformer(**cfg)
|
||||
for imgs in [imgs1, imgs2, imgs3]:
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token, cls_token = outs[-1]
|
||||
expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
|
||||
math.ceil(imgs.shape[3] / 16))
|
||||
self.assertEqual(patch_token.shape, (3, 768, *expect_feat_shape))
|
||||
self.assertEqual(cls_token.shape, (3, 768))
|
||||
|
|
|
@ -0,0 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn.functional as F


def timm_resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
    """Timm version pos embed resize function.

    copied from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    """  # noqa:E501
    ntok_new = posemb_new.shape[1]
    if num_tokens:
        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0,
                                                                 num_tokens:]
        ntok_new -= num_tokens
    else:
        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
    gs_old = int(math.sqrt(len(posemb_grid)))
    if not len(gs_new):  # backwards compatibility
        gs_new = [int(math.sqrt(ntok_new))] * 2
    assert len(gs_new) >= 2
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old,
                                      -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(
        posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
    posemb_grid = posemb_grid.permute(0, 2, 3,
                                      1).reshape(1, gs_new[0] * gs_new[1], -1)
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
    return posemb
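A minimal shape-level sketch of how the backbone tests call this reference helper, using the function defined above (values are illustrative, assuming a ViT-base checkpoint resized from a 14x14 to a 24x24 grid):

import torch

# A pos_embed pretrained on a 14x14 grid plus one cls token, resized to match a
# model built with img_size=384 and patch_size=16 (a 24x24 grid).
old_posemb = torch.rand(1, 14 * 14 + 1, 768)
new_posemb = torch.zeros(1, 24 * 24 + 1, 768)
resized = timm_resize_pos_embed(old_posemb, new_posemb, num_tokens=1)
assert resized.shape == new_posemb.shape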
@ -1,5 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
from unittest import TestCase
|
||||
from unittest.mock import ANY, MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcls.models.utils.attention import ShiftWindowMSA, WindowMSA
|
||||
|
@ -22,157 +25,177 @@ def get_relative_position_index(window_size):
|
|||
return relative_position_index
|
||||
|
||||
|
||||
def test_window_msa():
|
||||
batch_size = 1
|
||||
num_windows = (4, 4)
|
||||
embed_dims = 96
|
||||
window_size = (7, 7)
|
||||
num_heads = 4
|
||||
attn = WindowMSA(
|
||||
embed_dims=embed_dims, window_size=window_size, num_heads=num_heads)
|
||||
inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
|
||||
window_size[0] * window_size[1], embed_dims))
|
||||
class TestWindowMSA(TestCase):
|
||||
|
||||
# test forward
|
||||
output = attn(inputs)
|
||||
assert output.shape == inputs.shape
|
||||
assert attn.relative_position_bias_table.shape == (
|
||||
(2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
|
||||
def test_forward(self):
|
||||
attn = WindowMSA(embed_dims=96, window_size=(7, 7), num_heads=4)
|
||||
inputs = torch.rand((16, 7 * 7, 96))
|
||||
output = attn(inputs)
|
||||
self.assertEqual(output.shape, inputs.shape)
|
||||
|
||||
    # test relative_position_bias_table init
    attn.init_weights()
    assert abs(attn.relative_position_bias_table).sum() > 0
        # test non-square window_size
        attn = WindowMSA(embed_dims=96, window_size=(6, 7), num_heads=4)
        inputs = torch.rand((16, 6 * 7, 96))
        output = attn(inputs)
        self.assertEqual(output.shape, inputs.shape)

    # test non-square window_size
    window_size = (6, 7)
    attn = WindowMSA(
        embed_dims=embed_dims, window_size=window_size, num_heads=num_heads)
    inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
                         window_size[0] * window_size[1], embed_dims))
    output = attn(inputs)
    assert output.shape == inputs.shape
    def test_relative_pos_embed(self):
        attn = WindowMSA(embed_dims=96, window_size=(7, 8), num_heads=4)
        self.assertEqual(attn.relative_position_bias_table.shape,
                         ((2 * 7 - 1) * (2 * 8 - 1), 4))
        # test relative_position_index
        expected_rel_pos_index = get_relative_position_index((7, 8))
        self.assertTrue(
            torch.allclose(attn.relative_position_index,
                           expected_rel_pos_index))

    # test relative_position_index
    expected_rel_pos_index = get_relative_position_index(window_size)
    assert (attn.relative_position_index == expected_rel_pos_index).all()
        # test default init
        self.assertTrue(
            torch.allclose(attn.relative_position_bias_table,
                           torch.tensor(0.)))
        attn.init_weights()
        self.assertFalse(
            torch.allclose(attn.relative_position_bias_table,
                           torch.tensor(0.)))

    # test qkv_bias=True
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        qkv_bias=True)
    assert attn.qkv.bias.shape == (embed_dims * 3, )
    def test_qkv_bias(self):
        # test qkv_bias=True
        attn = WindowMSA(
            embed_dims=96, window_size=(7, 7), num_heads=4, qkv_bias=True)
        self.assertEqual(attn.qkv.bias.shape, (96 * 3, ))

    # test qkv_bias=False
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        qkv_bias=False)
    assert attn.qkv.bias is None
        # test qkv_bias=False
        attn = WindowMSA(
            embed_dims=96, window_size=(7, 7), num_heads=4, qkv_bias=False)
        self.assertIsNone(attn.qkv.bias)

    # test default qk_scale
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        qk_scale=None)
    head_dims = embed_dims // num_heads
    assert np.isclose(attn.scale, head_dims**-0.5)
    def test_qk_scale(self):
        # test default qk_scale
        attn = WindowMSA(
            embed_dims=96, window_size=(7, 7), num_heads=4, qk_scale=None)
        head_dims = 96 // 4
        self.assertAlmostEqual(attn.scale, head_dims**-0.5)

    # test specified qk_scale
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        qk_scale=0.3)
    assert attn.scale == 0.3
        # test specified qk_scale
        attn = WindowMSA(
            embed_dims=96, window_size=(7, 7), num_heads=4, qk_scale=0.3)
        self.assertEqual(attn.scale, 0.3)

    # test attn_drop
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        attn_drop=1.0)
    inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
                         window_size[0] * window_size[1], embed_dims))
    # drop all attn output, output should be equal to proj.bias
    assert torch.allclose(attn(inputs), attn.proj.bias)
    def test_attn_drop(self):
        inputs = torch.rand(16, 7 * 7, 96)
        attn = WindowMSA(
            embed_dims=96, window_size=(7, 7), num_heads=4, attn_drop=1.0)
        # drop all attn output, output should be equal to proj.bias
        self.assertTrue(torch.allclose(attn(inputs), attn.proj.bias))

    # test proj_drop
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        proj_drop=1.0)
    assert (attn(inputs) == 0).all()
    def test_proj_drop(self):
        inputs = torch.rand(16, 7 * 7, 96)
        attn = WindowMSA(
            embed_dims=96, window_size=(7, 7), num_heads=4, proj_drop=1.0)
        self.assertTrue(torch.allclose(attn(inputs), torch.tensor(0.)))

    def test_mask(self):
        inputs = torch.rand(16, 7 * 7, 96)
        attn = WindowMSA(embed_dims=96, window_size=(7, 7), num_heads=4)
        mask = torch.zeros((4, 49, 49))
        # Mask the attention from and to the first token in each window
        mask[:, 0, :] = -100
        mask[:, :, 0] = -100
        outs = attn(inputs, mask=mask)
        inputs[:, 0, :].normal_()
        outs_with_mask = attn(inputs, mask=mask)
        torch.testing.assert_allclose(outs[:, 1:, :], outs_with_mask[:, 1:, :])


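As the tests above construct it, WindowMSA works on an input that is already partitioned into windows, shaped (batch * num_windows, window_h * window_w, embed_dims), and returns a tensor of the same shape. The snippet below is a minimal usage sketch distilled from those tests; the shapes are the same illustrative ones the tests use.

import torch
from mmcls.models.utils.attention import WindowMSA

attn = WindowMSA(embed_dims=96, window_size=(7, 7), num_heads=4)
attn.init_weights()  # fills the relative position bias table

# 16 windows of 7x7 tokens with 96 channels each.
x = torch.rand(16, 7 * 7, 96)
out = attn(x)
assert out.shape == x.shape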
def test_shift_window_msa():
    batch_size = 1
    embed_dims = 96
    input_resolution = (14, 14)
    num_heads = 4
    window_size = 7
class TestShiftWindowMSA(TestCase):

    # test forward
    attn = ShiftWindowMSA(
        embed_dims=embed_dims,
        input_resolution=input_resolution,
        num_heads=num_heads,
        window_size=window_size)
    inputs = torch.rand(
        (batch_size, input_resolution[0] * input_resolution[1], embed_dims))
    output = attn(inputs)
    assert output.shape == inputs.shape
    assert attn.w_msa.relative_position_bias_table.shape == ((2 * window_size -
                                                              1)**2, num_heads)
    def test_forward(self):
        inputs = torch.rand((1, 14 * 14, 96))
        attn = ShiftWindowMSA(embed_dims=96, window_size=7, num_heads=4)
        output = attn(inputs, (14, 14))
        self.assertEqual(output.shape, inputs.shape)
        self.assertEqual(attn.w_msa.relative_position_bias_table.shape,
                         ((2 * 7 - 1)**2, 4))

    # test forward with shift_size
    attn = ShiftWindowMSA(
        embed_dims=embed_dims,
        input_resolution=input_resolution,
        num_heads=num_heads,
        window_size=window_size,
        shift_size=1)
    output = attn(inputs)
    assert output.shape == inputs.shape
        # test forward with shift_size
        attn = ShiftWindowMSA(
            embed_dims=96, window_size=7, num_heads=4, shift_size=3)
        output = attn(inputs, (14, 14))
        assert output.shape == inputs.shape

    # test relative_position_bias_table init
    attn.init_weights()
    assert abs(attn.w_msa.relative_position_bias_table).sum() > 0
        # test irregular input shape
        input_resolution = (19, 18)
        attn = ShiftWindowMSA(embed_dims=96, num_heads=4, window_size=7)
        inputs = torch.rand((1, 19 * 18, 96))
        output = attn(inputs, input_resolution)
        assert output.shape == inputs.shape

    # test dropout_layer
    attn = ShiftWindowMSA(
        embed_dims=embed_dims,
        input_resolution=input_resolution,
        num_heads=num_heads,
        window_size=window_size,
        dropout_layer=dict(type='DropPath', drop_prob=0.5))
    torch.manual_seed(0)
    output = attn(inputs)
    assert (output == 0).all()
        # test wrong input_resolution
        input_resolution = (14, 14)
        attn = ShiftWindowMSA(embed_dims=96, num_heads=4, window_size=7)
        inputs = torch.rand((1, 14 * 14, 96))
        with pytest.raises(AssertionError):
            attn(inputs, (14, 15))

    # test auto_pad
    input_resolution = (19, 18)
    attn = ShiftWindowMSA(
        embed_dims=embed_dims,
        input_resolution=input_resolution,
        num_heads=num_heads,
        window_size=window_size,
        auto_pad=True)
    assert attn.pad_r == 3
    assert attn.pad_b == 2
    def test_pad_small_map(self):
        # test pad_small_map=True
        inputs = torch.rand((1, 6 * 7, 96))
        attn = ShiftWindowMSA(
            embed_dims=96,
            window_size=7,
            num_heads=4,
            shift_size=3,
            pad_small_map=True)
        attn.get_attn_mask = MagicMock(wraps=attn.get_attn_mask)
        output = attn(inputs, (6, 7))
        self.assertEqual(output.shape, inputs.shape)
        attn.get_attn_mask.assert_called_once_with((7, 7),
                                                   window_size=7,
                                                   shift_size=3,
                                                   device=ANY)

    # test small input_resolution
    input_resolution = (5, 6)
    attn = ShiftWindowMSA(
        embed_dims=embed_dims,
        input_resolution=input_resolution,
        num_heads=num_heads,
        window_size=window_size,
        shift_size=3,
        auto_pad=True)
    assert attn.window_size == 5
    assert attn.shift_size == 0
        # test pad_small_map=False
        inputs = torch.rand((1, 6 * 7, 96))
        attn = ShiftWindowMSA(
            embed_dims=96,
            window_size=7,
            num_heads=4,
            shift_size=3,
            pad_small_map=False)
        with self.assertRaisesRegex(AssertionError, r'the window size \(7\)'):
            attn(inputs, (6, 7))

        # test pad_small_map=False, and the input size equals the window size
        inputs = torch.rand((1, 7 * 7, 96))
        attn.get_attn_mask = MagicMock(wraps=attn.get_attn_mask)
        output = attn(inputs, (7, 7))
        self.assertEqual(output.shape, inputs.shape)
        attn.get_attn_mask.assert_called_once_with((7, 7),
                                                   window_size=7,
                                                   shift_size=0,
                                                   device=ANY)

    def test_drop_layer(self):
        inputs = torch.rand((1, 14 * 14, 96))
        attn = ShiftWindowMSA(
            embed_dims=96,
            window_size=7,
            num_heads=4,
            dropout_layer=dict(type='Dropout', drop_prob=1.0))
        attn.init_weights()
        # drop all attn output, so the final output should be all zeros
        self.assertTrue(
            torch.allclose(attn(inputs, (14, 14)), torch.tensor(0.)))

    def test_deprecation(self):
        # test deprecated arguments
        with pytest.warns(DeprecationWarning):
            ShiftWindowMSA(
                embed_dims=96,
                num_heads=4,
                window_size=7,
                input_resolution=(14, 14))

        with pytest.warns(DeprecationWarning):
            ShiftWindowMSA(
                embed_dims=96, num_heads=4, window_size=7, auto_pad=True)
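The TestShiftWindowMSA cases above document the refactored interface: the spatial shape of the feature map is passed to each forward call instead of being fixed by input_resolution at construction time, the deprecated input_resolution and auto_pad arguments raise a DeprecationWarning, and pad_small_map=True pads feature maps smaller than the window, which are otherwise rejected. A short usage sketch distilled from those tests, with the same illustrative shapes:

import torch
from mmcls.models.utils.attention import ShiftWindowMSA

attn = ShiftWindowMSA(
    embed_dims=96, num_heads=4, window_size=7, shift_size=3,
    pad_small_map=True)

# The same module handles different spatial shapes at runtime.
x = torch.rand(1, 14 * 14, 96)
assert attn(x, (14, 14)).shape == x.shape

# A 6x7 map is smaller than the 7x7 window; pad_small_map=True pads it.
x_small = torch.rand(1, 6 * 7, 96)
assert attn(x_small, (6, 7)).shape == x_small.shape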