# EasyCV/easycv/models/backbones/shuffle_transformer.py

# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the code is borrowed from:
# https://github.com/mulinmeng/Shuffle-Transformer/blob/main/models/shuffle_transformer.py
import torch
from einops import rearrange
from timm.models.layers import DropPath, trunc_normal_
from torch import nn

from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
from ..registry import BACKBONES


class Mlp(nn.Module):

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.ReLU6,
                 drop=0.,
                 stride=False):
        super().__init__()
        self.stride = stride
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)
        self.drop = nn.Dropout(drop, inplace=True)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=1,
                 shuffle=False,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 relative_pos_embedding=False):
        super().__init__()
        self.num_heads = num_heads
        self.relative_pos_embedding = relative_pos_embedding
        head_dim = dim // self.num_heads
        self.ws = window_size
        self.shuffle = shuffle
        self.scale = qk_scale or head_dim**-0.5

        self.to_qkv = nn.Conv2d(dim, dim * 3, 1, bias=False)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim, dim, 1)
        self.proj_drop = nn.Dropout(proj_drop)

        if self.relative_pos_embedding:
            # define a parameter table of relative position bias
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros((2 * window_size - 1) * (2 * window_size - 1),
                            num_heads))  # 2*Wh-1 * 2*Ww-1, nH

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(self.ws)
            coords_w = torch.arange(self.ws)
            coords = torch.stack(torch.meshgrid([coords_h,
                                                 coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(
                1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += self.ws - 1  # shift to start from 0
            relative_coords[:, :, 1] += self.ws - 1
            relative_coords[:, :, 0] *= 2 * self.ws - 1
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer('relative_position_index',
                                 relative_position_index)

            trunc_normal_(self.relative_position_bias_table, std=.02)
            print('The relative_pos_embedding is used')

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)

        if self.shuffle:
            q, k, v = rearrange(
                qkv,
                'b (qkv h d) (ws1 hh) (ws2 ww) -> qkv (b hh ww) h (ws1 ws2) d',
                h=self.num_heads,
                qkv=3,
                ws1=self.ws,
                ws2=self.ws)
        else:
            q, k, v = rearrange(
                qkv,
                'b (qkv h d) (hh ws1) (ww ws2) -> qkv (b hh ww) h (ws1 ws2) d',
                h=self.num_heads,
                qkv=3,
                ws1=self.ws,
                ws2=self.ws)

        dots = (q @ k.transpose(-2, -1)) * self.scale

        if self.relative_pos_embedding:
            relative_position_bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)].view(
                    self.ws * self.ws, self.ws * self.ws, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(
                2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            dots += relative_position_bias.unsqueeze(0)

        attn = dots.softmax(dim=-1)
        out = attn @ v

        if self.shuffle:
            out = rearrange(
                out,
                '(b hh ww) h (ws1 ws2) d -> b (h d) (ws1 hh) (ws2 ww)',
                h=self.num_heads,
                b=b,
                hh=h // self.ws,
                ws1=self.ws,
                ws2=self.ws)
        else:
            out = rearrange(
                out,
                '(b hh ww) h (ws1 ws2) d -> b (h d) (hh ws1) (ww ws2)',
                h=self.num_heads,
                b=b,
                hh=h // self.ws,
                ws1=self.ws,
                ws2=self.ws)

        out = self.proj(out)
        out = self.proj_drop(out)
        return out
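

# Note on Attention above: for an input of shape (b, dim, h, w) with window
# size ws, the rearranges partition the map into (h // ws) * (w // ws) windows
# and compute self-attention over the ws * ws tokens of each window, giving
# q, k, v of shape ((b * h // ws * w // ws), num_heads, ws * ws, dim // num_heads).
# The non-shuffled pattern '(hh ws1) (ww ws2)' groups spatially contiguous
# windows, while the shuffled pattern '(ws1 hh) (ws2 ww)' groups pixels sampled
# with stride (h // ws, w // ws) across the whole map (the spatial shuffle).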


class Block(nn.Module):

    def __init__(self,
                 dim,
                 out_dim,
                 num_heads,
                 window_size=1,
                 shuffle=False,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.ReLU6,
                 norm_layer=nn.BatchNorm2d,
                 stride=False,
                 relative_pos_embedding=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            window_size=window_size,
            shuffle=shuffle,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            relative_pos_embedding=relative_pos_embedding)
        # depth-wise conv (groups=dim) used as the local residual branch
        self.local = nn.Conv2d(
            dim,
            dim,
            window_size,
            1,
            window_size // 2,
            groups=dim,
            bias=qkv_bias)
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            out_features=out_dim,
            act_layer=act_layer,
            drop=drop,
            stride=stride)
        self.norm3 = norm_layer(dim)
        print(
            'input dim={}, output dim={}, stride={}, shuffle={}, num_heads={}'.
            format(dim, out_dim, stride, shuffle, num_heads))

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.local(self.norm2(x))  # local connection
        x = x + self.drop_path(self.mlp(self.norm3(x)))
        return x
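

# Note on Block.forward above: each block applies three residual branches in
# sequence -- window/shuffle attention, a depth-wise convolution with kernel
# size window_size, and the point-wise MLP -- each preceded by its own norm
# layer (BatchNorm2d by default). DropPath is applied only to the attention
# and MLP branches.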


class PatchMerging(nn.Module):

    def __init__(self, dim, out_dim, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.norm = norm_layer(dim)
        self.reduction = nn.Conv2d(dim, out_dim, 2, 2, 0, bias=False)

    def forward(self, x):
        x = self.norm(x)
        x = self.reduction(x)
        return x

    def extra_repr(self) -> str:
        return f'input dim={self.dim}, out dim={self.out_dim}'


class StageModule(nn.Module):

    def __init__(self,
                 layers,
                 dim,
                 out_dim,
                 num_heads,
                 window_size=1,
                 shuffle=True,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.ReLU6,
                 norm_layer=nn.BatchNorm2d,
                 relative_pos_embedding=False):
        super().__init__()
        assert layers % 2 == 0, 'Stage layers need to be divisible by 2 for regular and shifted block.'

        if dim != out_dim:
            self.patch_partition = PatchMerging(dim, out_dim)
        else:
            self.patch_partition = None

        num = layers // 2
        self.layers = nn.ModuleList([])
        for idx in range(num):
            the_last = (idx == num - 1)
            self.layers.append(
                nn.ModuleList([
                    Block(
                        dim=out_dim,
                        out_dim=out_dim,
                        num_heads=num_heads,
                        window_size=window_size,
                        shuffle=False,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        drop=drop,
                        attn_drop=attn_drop,
                        drop_path=drop_path,
                        relative_pos_embedding=relative_pos_embedding),
                    Block(
                        dim=out_dim,
                        out_dim=out_dim,
                        num_heads=num_heads,
                        window_size=window_size,
                        shuffle=shuffle,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        drop=drop,
                        attn_drop=attn_drop,
                        drop_path=drop_path,
                        relative_pos_embedding=relative_pos_embedding)
                ]))

    def forward(self, x):
        if self.patch_partition:
            x = self.patch_partition(x)

        for regular_block, shifted_block in self.layers:
            x = regular_block(x)
            x = shifted_block(x)
        return x
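

# Note on StageModule above: a stage optionally halves the spatial resolution
# with PatchMerging (only when the channel dim changes), then stacks
# layers // 2 pairs of Blocks -- one with shuffle=False (plain window
# attention) followed by one with shuffle=shuffle -- which is why `layers`
# must be even.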


class PatchEmbedding(nn.Module):

    def __init__(self, inter_channel=32, out_channels=48):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, inter_channel, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(inter_channel), nn.ReLU6(inplace=True))
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                inter_channel,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1), nn.BatchNorm2d(out_channels),
            nn.ReLU6(inplace=True))
        self.conv3 = nn.Conv2d(
            out_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.conv3(self.conv2(self.conv1(x)))
        return x
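

# Note on PatchEmbedding above: the two stride-2 convolutions downsample the
# input by a factor of 4, so a 3 x 224 x 224 image becomes an
# out_channels x 56 x 56 token map (the "p4" in the factory names below).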


@BACKBONES.register_module
class ShuffleTransformer(nn.Module):

    def __init__(self,
                 img_size=224,
                 in_chans=3,
                 num_classes=1000,
                 token_dim=32,
                 embed_dim=96,
                 mlp_ratio=4.,
                 layers=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 relative_pos_embedding=True,
                 shuffle=True,
                 window_size=7,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 has_pos_embed=False,
                 **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.has_pos_embed = has_pos_embed
        dims = [i * 32 for i in num_heads]

        self.to_token = PatchEmbedding(
            inter_channel=token_dim, out_channels=embed_dim)
        num_patches = (img_size * img_size) // 16

        if self.has_pos_embed:
            raise NotImplementedError
            # self.pos_embed = nn.Parameter(
            #     data=get_sinusoid_encoding(
            #         n_position=num_patches, d_hid=embed_dim),
            #     requires_grad=False)
            # self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 4)
               ]  # stochastic depth decay rule
        self.stage1 = StageModule(
            layers[0],
            embed_dim,
            dims[0],
            num_heads[0],
            window_size=window_size,
            shuffle=shuffle,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr[0],
            relative_pos_embedding=relative_pos_embedding)
        self.stage2 = StageModule(
            layers[1],
            dims[0],
            dims[1],
            num_heads[1],
            window_size=window_size,
            shuffle=shuffle,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr[1],
            relative_pos_embedding=relative_pos_embedding)
        self.stage3 = StageModule(
            layers[2],
            dims[1],
            dims[2],
            num_heads[2],
            window_size=window_size,
            shuffle=shuffle,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr[2],
            relative_pos_embedding=relative_pos_embedding)
        self.stage4 = StageModule(
            layers[3],
            dims[2],
            dims[3],
            num_heads[3],
            window_size=window_size,
            shuffle=shuffle,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr[3],
            relative_pos_embedding=relative_pos_embedding)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Classifier head
        self.head = nn.Linear(
            dims[3], num_classes) if num_classes > 0 else nn.Identity()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.Linear, nn.Conv2d)):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m,
                              (nn.Linear, nn.Conv2d)) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(
            self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.to_token(x)
        b, c, h, w = x.shape

        if self.has_pos_embed:
            x = x + self.pos_embed.view(1, h, w, c).permute(0, 3, 1, 2)
            x = self.pos_drop(x)

        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return [x]
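

# Note on ShuffleTransformer above: the stage widths are
# dims = [32 * h for h in num_heads], i.e. [96, 192, 384, 768] with the
# default heads, so the pooled feature fed to the classification head has
# dims[3] channels, and forward() returns a single-element list containing
# the (B, num_classes) logits.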


def shuffletrans_base_p4_w7_224(pretrained=False, **kwargs):
    model = ShuffleTransformer(
        img_size=224,
        in_chans=3,
        num_classes=kwargs['num_classes'],
        token_dim=32,
        embed_dim=128,
        mlp_ratio=4.,
        layers=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        relative_pos_embedding=True,
        shuffle=True,
        window_size=7,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.5,
        has_pos_embed=False)
    return model


def shuffletrans_small_p4_w7_224(pretrained=False, **kwargs):
    model = ShuffleTransformer(
        img_size=224,
        in_chans=3,
        num_classes=kwargs['num_classes'],
        token_dim=32,
        embed_dim=96,
        mlp_ratio=4.,
        layers=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        relative_pos_embedding=True,
        shuffle=True,
        window_size=7,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.3,
        has_pos_embed=False)
    return model


def shuffletrans_tiny_p4_w7_224(pretrained=False, **kwargs):
    model = ShuffleTransformer(
        img_size=224,
        in_chans=3,
        num_classes=kwargs['num_classes'],
        token_dim=32,
        embed_dim=96,
        mlp_ratio=4.,
        layers=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        relative_pos_embedding=True,
        shuffle=True,
        window_size=7,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.1,
        has_pos_embed=False)
    return model
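

# Minimal usage sketch of the factory functions above (illustrative; assumes
# timm and einops are installed and the EasyCV registry imports resolve).
# `num_classes` must be passed explicitly because the factories read it from
# **kwargs.
if __name__ == '__main__':
    model = shuffletrans_tiny_p4_w7_224(num_classes=1000)
    model.init_weights()
    model.eval()  # use BatchNorm running stats for the dummy forward pass

    dummy = torch.randn(1, 3, 224, 224)  # 224 matches the factory defaults
    with torch.no_grad():
        logits = model(dummy)[0]  # forward() returns a single-element list
    print(logits.shape)  # torch.Size([1, 1000])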