mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Refactor] Using mmcv transformer bricks to refactor vit. (#571)
* [Refactor] Using mmcv bricks to refactor vit * Follow the vit code structure from mmclassification * Add MMCV install into CI system. * Add to 'Install MMCV' CI item * Add 'Install MMCV_CPU' and 'Install MMCV_GPU CI' items * Fix & Add 1. Fix low code coverage of vit.py; 2. Remove HybirdEmbed; 3. Fix doc string of VisionTransformer; * Add helpers unit test. * Add converter to convert vit pretrain weights from timm style to mmcls style. * Clean some rebundant code and refactor init 1. Use timm style init_weights; 2. Remove to_xtuple and trunc_norm_; * Add comments for VisionTransformer.init_weights() * Add arg: pretrain_style to choose timm or mmcls vit pretrain weights.
This commit is contained in:
parent
76e0d673e9
commit
c01abb4f30
@ -1,294 +1,257 @@
|
|||||||
"""Modified from https://github.com/rwightman/pytorch-image-
|
|
||||||
models/blob/master/timm/models/vision_transformer.py."""
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint as cp
|
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
|
||||||
from mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer,
|
kaiming_init, normal_init, trunc_normal_init)
|
||||||
constant_init, kaiming_init, normal_init)
|
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
|
||||||
from mmcv.runner import BaseModule, _load_checkpoint
|
from mmcv.runner import _load_checkpoint
|
||||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
from mmcv.runner.base_module import BaseModule, ModuleList
|
||||||
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
from torch.nn.modules.utils import _pair as to_2tuple
|
||||||
|
|
||||||
from mmseg.utils import get_root_logger
|
from mmseg.utils import get_root_logger
|
||||||
from ..builder import BACKBONES
|
from ..builder import BACKBONES
|
||||||
from ..utils import DropPath, trunc_normal_
|
from ..utils import vit_convert
|
||||||
|
|
||||||
|
|
||||||
class Mlp(nn.Module):
|
class TransformerEncoderLayer(BaseModule):
|
||||||
"""MLP layer for Encoder block.
|
"""Implements one encoder layer in Vision Transformer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
in_features(int): Input dimension for the first fully
|
embed_dims (int): The feature dimension
|
||||||
connected layer.
|
num_heads (int): Parallel attention heads
|
||||||
hidden_features(int): Output dimension for the first fully
|
feedforward_channels (int): The hidden dimension for FFNs
|
||||||
connected layer.
|
drop_rate (float): Probability of an element to be zeroed
|
||||||
out_features(int): Output dementsion for the second fully
|
after the feed forward layer. Default 0.0
|
||||||
connected layer.
|
attn_drop_rate (float): The drop out rate for attention layer.
|
||||||
act_cfg(dict): Config dict for activation layer.
|
Default 0.0
|
||||||
Default: dict(type='GELU').
|
drop_path_rate (float): stochastic depth rate. Default 0.0.
|
||||||
drop(float): Drop rate for the dropout layer. Dropout rate has
|
num_fcs (int): The number of fully-connected layers for FFNs. Default 2
|
||||||
to be between 0 and 1. Default: 0.
|
qkv_bias (bool): enable bias for qkv if True. Default True
|
||||||
|
act_cfg (dict): The activation config for FFNs. Defalut GELU
|
||||||
|
norm_cfg (dict): Config dict for normalization layer. Default
|
||||||
|
layer normalization
|
||||||
|
batch_first (bool): Key, Query and Value are shape of
|
||||||
|
(batch, n, embed_dim)
|
||||||
|
or (n, batch, embed_dim). Default to False.
|
||||||
|
init_cfg (dict, optional): Initialization config dict
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
in_features,
|
embed_dims,
|
||||||
hidden_features=None,
|
|
||||||
out_features=None,
|
|
||||||
act_cfg=dict(type='GELU'),
|
|
||||||
drop=0.):
|
|
||||||
super(Mlp, self).__init__()
|
|
||||||
out_features = out_features or in_features
|
|
||||||
hidden_features = hidden_features or in_features
|
|
||||||
self.fc1 = Linear(in_features, hidden_features)
|
|
||||||
self.act = build_activation_layer(act_cfg)
|
|
||||||
self.fc2 = Linear(hidden_features, out_features)
|
|
||||||
self.drop = nn.Dropout(drop)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.fc1(x)
|
|
||||||
x = self.act(x)
|
|
||||||
x = self.drop(x)
|
|
||||||
x = self.fc2(x)
|
|
||||||
x = self.drop(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
|
||||||
"""Attention layer for Encoder block.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dim (int): Dimension for the input vector.
|
|
||||||
num_heads (int): Number of parallel attention heads.
|
|
||||||
qkv_bias (bool): Enable bias for qkv if True. Default: False.
|
|
||||||
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
|
||||||
attn_drop (float): Drop rate for attention output weights.
|
|
||||||
Default: 0.
|
|
||||||
proj_drop (float): Drop rate for output weights. Default: 0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
dim,
|
|
||||||
num_heads=8,
|
|
||||||
qkv_bias=False,
|
|
||||||
qk_scale=None,
|
|
||||||
attn_drop=0.,
|
|
||||||
proj_drop=0.):
|
|
||||||
super(Attention, self).__init__()
|
|
||||||
self.num_heads = num_heads
|
|
||||||
head_dim = dim // num_heads
|
|
||||||
self.scale = qk_scale or head_dim**-0.5
|
|
||||||
|
|
||||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
|
||||||
self.attn_drop = nn.Dropout(attn_drop)
|
|
||||||
self.proj = Linear(dim, dim)
|
|
||||||
self.proj_drop = nn.Dropout(proj_drop)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
b, n, c = x.shape
|
|
||||||
qkv = self.qkv(x).reshape(b, n, 3, self.num_heads,
|
|
||||||
c // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
||||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
|
||||||
|
|
||||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
|
||||||
attn = attn.softmax(dim=-1)
|
|
||||||
attn = self.attn_drop(attn)
|
|
||||||
|
|
||||||
x = (attn @ v).transpose(1, 2).reshape(b, n, c)
|
|
||||||
x = self.proj(x)
|
|
||||||
x = self.proj_drop(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Block(nn.Module):
|
|
||||||
"""Implements encoder block with residual connection.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dim (int): The feature dimension.
|
|
||||||
num_heads (int): Number of parallel attention heads.
|
|
||||||
mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
|
|
||||||
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
|
||||||
drop (float): Drop rate for mlp output weights. Default: 0.
|
|
||||||
attn_drop (float): Drop rate for attention output weights.
|
|
||||||
Default: 0.
|
|
||||||
proj_drop (float): Drop rate for attn layer output weights.
|
|
||||||
Default: 0.
|
|
||||||
drop_path (float): Drop rate for paths of model.
|
|
||||||
Default: 0.
|
|
||||||
act_cfg (dict): Config dict for activation layer.
|
|
||||||
Default: dict(type='GELU').
|
|
||||||
norm_cfg (dict): Config dict for normalization layer.
|
|
||||||
Default: dict(type='LN', requires_grad=True).
|
|
||||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
|
||||||
memory while slowing down the training speed. Default: False.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
dim,
|
|
||||||
num_heads,
|
num_heads,
|
||||||
mlp_ratio=4,
|
feedforward_channels,
|
||||||
qkv_bias=False,
|
drop_rate=0.,
|
||||||
qk_scale=None,
|
attn_drop_rate=0.,
|
||||||
drop=0.,
|
drop_path_rate=0.,
|
||||||
attn_drop=0.,
|
num_fcs=2,
|
||||||
proj_drop=0.,
|
qkv_bias=True,
|
||||||
drop_path=0.,
|
|
||||||
act_cfg=dict(type='GELU'),
|
act_cfg=dict(type='GELU'),
|
||||||
norm_cfg=dict(type='LN', eps=1e-6),
|
norm_cfg=dict(type='LN'),
|
||||||
with_cp=False):
|
batch_first=False):
|
||||||
super(Block, self).__init__()
|
super(TransformerEncoderLayer, self).__init__()
|
||||||
self.with_cp = with_cp
|
|
||||||
_, self.norm1 = build_norm_layer(norm_cfg, dim)
|
self.norm1_name, norm1 = build_norm_layer(
|
||||||
self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop,
|
norm_cfg, embed_dims, postfix=1)
|
||||||
proj_drop)
|
self.add_module(self.norm1_name, norm1)
|
||||||
self.drop_path = DropPath(
|
|
||||||
drop_path) if drop_path > 0. else nn.Identity()
|
self.attn = MultiheadAttention(
|
||||||
_, self.norm2 = build_norm_layer(norm_cfg, dim)
|
embed_dims=embed_dims,
|
||||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
num_heads=num_heads,
|
||||||
self.mlp = Mlp(
|
attn_drop=attn_drop_rate,
|
||||||
in_features=dim,
|
proj_drop=drop_rate,
|
||||||
hidden_features=mlp_hidden_dim,
|
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||||
act_cfg=act_cfg,
|
batch_first=batch_first,
|
||||||
drop=drop)
|
bias=qkv_bias)
|
||||||
|
|
||||||
|
self.norm2_name, norm2 = build_norm_layer(
|
||||||
|
norm_cfg, embed_dims, postfix=2)
|
||||||
|
self.add_module(self.norm2_name, norm2)
|
||||||
|
|
||||||
|
self.ffn = FFN(
|
||||||
|
embed_dims=embed_dims,
|
||||||
|
feedforward_channels=feedforward_channels,
|
||||||
|
num_fcs=num_fcs,
|
||||||
|
ffn_drop=drop_rate,
|
||||||
|
dropout_layer=None,
|
||||||
|
act_cfg=act_cfg)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def norm1(self):
|
||||||
|
return getattr(self, self.norm1_name)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def norm2(self):
|
||||||
|
return getattr(self, self.norm2_name)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
x = self.attn(self.norm1(x), identity=x)
|
||||||
def _inner_forward(x):
|
x = self.ffn(self.norm2(x), identity=x)
|
||||||
out = x + self.drop_path(self.attn(self.norm1(x)))
|
return x
|
||||||
out = out + self.drop_path(self.mlp(self.norm2(out)))
|
|
||||||
return out
|
|
||||||
|
|
||||||
if self.with_cp and x.requires_grad:
|
|
||||||
out = cp.checkpoint(_inner_forward, x)
|
|
||||||
else:
|
|
||||||
out = _inner_forward(x)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class PatchEmbed(nn.Module):
|
# Modified from pytorch-image-models
|
||||||
|
class PatchEmbed(BaseModule):
|
||||||
"""Image to Patch Embedding.
|
"""Image to Patch Embedding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
img_size (int | tuple): Input image size.
|
img_size (int | tuple): The size of input image.
|
||||||
default: 224.
|
patch_size (int): The size of one patch
|
||||||
patch_size (int): Width and height for a patch.
|
in_channels (int): The num of input channels.
|
||||||
default: 16.
|
embed_dim (int): The dimensions of embedding.
|
||||||
in_channels (int): Input channels for images. Default: 3.
|
norm_cfg (dict, optional): Config dict for normalization layer.
|
||||||
embed_dim (int): The embedding dimension. Default: 768.
|
conv_cfg (dict, optional): The config dict for conv layers.
|
||||||
|
Default: None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
img_size=224,
|
img_size=224,
|
||||||
patch_size=16,
|
patch_size=16,
|
||||||
in_channels=3,
|
in_channels=3,
|
||||||
embed_dim=768):
|
embed_dim=768,
|
||||||
|
norm_cfg=None,
|
||||||
|
conv_cfg=None):
|
||||||
super(PatchEmbed, self).__init__()
|
super(PatchEmbed, self).__init__()
|
||||||
if isinstance(img_size, int):
|
|
||||||
self.img_size = (img_size, img_size)
|
self.img_size = img_size
|
||||||
elif isinstance(img_size, tuple):
|
self.patch_size = to_2tuple(patch_size)
|
||||||
self.img_size = img_size
|
|
||||||
|
patches_resolution = [
|
||||||
|
img_size[0] // self.patch_size[0],
|
||||||
|
img_size[1] // self.patch_size[1]
|
||||||
|
]
|
||||||
|
num_patches = patches_resolution[0] * patches_resolution[1]
|
||||||
|
self.patches_resolution = patches_resolution
|
||||||
|
self.num_patches = num_patches
|
||||||
|
|
||||||
|
# Use conv layer to embed
|
||||||
|
self.projection = build_conv_layer(
|
||||||
|
conv_cfg,
|
||||||
|
in_channels,
|
||||||
|
embed_dim,
|
||||||
|
kernel_size=patch_size,
|
||||||
|
stride=patch_size)
|
||||||
|
|
||||||
|
if norm_cfg is not None:
|
||||||
|
self.norm = build_norm_layer(norm_cfg, embed_dim)[1]
|
||||||
else:
|
else:
|
||||||
raise TypeError('img_size must be type of int or tuple')
|
self.norm = None
|
||||||
h, w = self.img_size
|
|
||||||
self.patch_size = (patch_size, patch_size)
|
|
||||||
self.num_patches = (h // patch_size) * (w // patch_size)
|
|
||||||
self.proj = Conv2d(
|
|
||||||
in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.proj(x).flatten(2).transpose(1, 2)
|
B, C, H, W = x.shape
|
||||||
|
# FIXME look at relaxing size constraints
|
||||||
|
# assert H == self.img_size[0] and W == self.img_size[1], \
|
||||||
|
# f"Input image size ({H}*{W}) doesn't " \
|
||||||
|
# f'match model ({self.img_size[0]}*{self.img_size[1]}).'
|
||||||
|
# The output size is (B, N, D), where N=H*W/P/P, D is embid_dim
|
||||||
|
x = self.projection(x).flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
if self.norm is not None:
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
@BACKBONES.register_module()
|
@BACKBONES.register_module()
|
||||||
class VisionTransformer(BaseModule):
|
class VisionTransformer(BaseModule):
|
||||||
"""Vision transformer backbone.
|
"""Vision Transformer.
|
||||||
|
|
||||||
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for
|
A PyTorch implement of : `An Image is Worth 16x16 Words:
|
||||||
Image Recognition at Scale` - https://arxiv.org/abs/2010.11929
|
Transformers for Image Recognition at Scale` -
|
||||||
|
https://arxiv.org/abs/2010.11929
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
img_size (tuple): input image size. Default: (224, 224).
|
img_size (int | tuple): Input image size. Default: 224.
|
||||||
patch_size (int, tuple): patch size. Default: 16.
|
patch_size (int): The patch size. Default: 16.
|
||||||
in_channels (int): number of input channels. Default: 3.
|
in_channels (int): Number of input channels. Default: 3.
|
||||||
embed_dim (int): embedding dimension. Default: 768.
|
embed_dims (int): embedding dimension. Default: 768.
|
||||||
depth (int): depth of transformer. Default: 12.
|
num_layers (int): depth of transformer. Default: 12.
|
||||||
num_heads (int): number of attention heads. Default: 12.
|
num_heads (int): number of attention heads. Default: 12.
|
||||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
|
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
|
||||||
Default: 4.
|
Default: 4.
|
||||||
out_indices (list | tuple | int): Output from which stages.
|
out_indices (list | tuple | int): Output from which stages.
|
||||||
Default: -1.
|
Default: -1.
|
||||||
qkv_bias (bool): enable bias for qkv if True. Default: True.
|
qkv_bias (bool): enable bias for qkv if True. Default: True.
|
||||||
qk_scale (float): override default qk scale of head_dim ** -0.5 if set.
|
drop_rate (float): Probability of an element to be zeroed.
|
||||||
drop_rate (float): dropout rate. Default: 0.
|
Default 0.0
|
||||||
attn_drop_rate (float): attention dropout rate. Default: 0.
|
attn_drop_rate (float): The drop out rate for attention layer.
|
||||||
drop_path_rate (float): Rate of DropPath. Default: 0.
|
Default 0.0
|
||||||
|
drop_path_rate (float): stochastic depth rate. Default 0.0
|
||||||
|
with_cls_token (bool): If concatenating class token into image tokens
|
||||||
|
as transformer input. Default: True.
|
||||||
norm_cfg (dict): Config dict for normalization layer.
|
norm_cfg (dict): Config dict for normalization layer.
|
||||||
Default: dict(type='LN', eps=1e-6, requires_grad=True).
|
Default: dict(type='LN')
|
||||||
act_cfg (dict): Config dict for activation layer.
|
act_cfg (dict): The activation config for FFNs.
|
||||||
Default: dict(type='GELU').
|
Defalut: dict(type='GELU').
|
||||||
|
final_norm (bool): Whether to add a additional layer to normalize
|
||||||
|
final feature map. Default: False.
|
||||||
|
interpolate_mode (str): Select the interpolate mode for position
|
||||||
|
embeding vector resize. Default: bicubic.
|
||||||
|
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||||
|
Default: 2.
|
||||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||||
and its variants only. Default: False.
|
and its variants only. Default: False.
|
||||||
final_norm (bool): Whether to add a additional layer to normalize
|
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||||
final feature map. Default: False.
|
some memory while slowing down the training speed. Default: False.
|
||||||
out_reshape (str): Select the output format of feature information.
|
pretrain_style (str): Choose to use timm or mmcls pretrain weights.
|
||||||
Default: NCHW.
|
Default: timm.
|
||||||
interpolate_mode (str): Select the interpolate mode for position
|
|
||||||
embeding vector resize. Default: bicubic.
|
|
||||||
with_cls_token (bool): If concatenating class token into image tokens
|
|
||||||
as transformer input. Default: True.
|
|
||||||
with_cp (bool): Use checkpoint or not. Using checkpoint
|
|
||||||
will save some memory while slowing down the training speed.
|
|
||||||
Default: False.
|
|
||||||
pretrained (str, optional): model pretrained path. Default: None
|
|
||||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
|
||||||
Default: None
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
img_size=(224, 224),
|
img_size=224,
|
||||||
patch_size=16,
|
patch_size=16,
|
||||||
in_channels=3,
|
in_channels=3,
|
||||||
embed_dim=768,
|
embed_dims=768,
|
||||||
depth=12,
|
num_layers=12,
|
||||||
num_heads=12,
|
num_heads=12,
|
||||||
mlp_ratio=4,
|
mlp_ratio=4,
|
||||||
out_indices=11,
|
out_indices=11,
|
||||||
qkv_bias=True,
|
qkv_bias=True,
|
||||||
qk_scale=None,
|
|
||||||
drop_rate=0.,
|
drop_rate=0.,
|
||||||
attn_drop_rate=0.,
|
attn_drop_rate=0.,
|
||||||
drop_path_rate=0.,
|
drop_path_rate=0.,
|
||||||
norm_cfg=dict(type='LN', eps=1e-6, requires_grad=True),
|
|
||||||
act_cfg=dict(type='GELU'),
|
|
||||||
norm_eval=False,
|
|
||||||
final_norm=False,
|
|
||||||
out_shape='NCHW',
|
|
||||||
with_cls_token=True,
|
with_cls_token=True,
|
||||||
|
norm_cfg=dict(type='LN'),
|
||||||
|
act_cfg=dict(type='GELU'),
|
||||||
|
final_norm=False,
|
||||||
interpolate_mode='bicubic',
|
interpolate_mode='bicubic',
|
||||||
|
num_fcs=2,
|
||||||
|
norm_eval=False,
|
||||||
with_cp=False,
|
with_cp=False,
|
||||||
pretrained=None,
|
pretrain_style='timm'):
|
||||||
init_cfg=None):
|
super(VisionTransformer, self).__init__()
|
||||||
super(VisionTransformer, self).__init__(init_cfg)
|
|
||||||
self.pretrained = pretrained
|
|
||||||
|
|
||||||
|
if isinstance(img_size, int):
|
||||||
|
img_size = to_2tuple(img_size)
|
||||||
|
elif isinstance(img_size, tuple):
|
||||||
|
if len(img_size) == 1:
|
||||||
|
img_size = to_2tuple(img_size[0])
|
||||||
|
assert len(img_size) == 2, \
|
||||||
|
f'The size of image should have length 1 or 2, ' \
|
||||||
|
f'but got {len(img_size)}'
|
||||||
|
|
||||||
|
assert pretrain_style in ['timm', 'mmcls']
|
||||||
|
|
||||||
|
self.pretrain_style = pretrain_style
|
||||||
self.img_size = img_size
|
self.img_size = img_size
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
self.features = self.embed_dim = embed_dim
|
|
||||||
self.patch_embed = PatchEmbed(
|
self.patch_embed = PatchEmbed(
|
||||||
img_size=img_size,
|
img_size=img_size,
|
||||||
patch_size=patch_size,
|
patch_size=patch_size,
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
embed_dim=embed_dim)
|
embed_dim=embed_dims,
|
||||||
|
norm_cfg=norm_cfg)
|
||||||
|
num_patches = self.patch_embed.num_patches
|
||||||
|
|
||||||
self.with_cls_token = with_cls_token
|
self.with_cls_token = with_cls_token
|
||||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
|
||||||
self.pos_embed = nn.Parameter(
|
self.pos_embed = nn.Parameter(
|
||||||
torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
|
torch.zeros(1, num_patches + 1, embed_dims))
|
||||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||||
|
|
||||||
if isinstance(out_indices, int):
|
if isinstance(out_indices, int):
|
||||||
self.out_indices = [out_indices]
|
self.out_indices = [out_indices]
|
||||||
@ -297,37 +260,41 @@ class VisionTransformer(BaseModule):
|
|||||||
else:
|
else:
|
||||||
raise TypeError('out_indices must be type of int, list or tuple')
|
raise TypeError('out_indices must be type of int, list or tuple')
|
||||||
|
|
||||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
dpr = [
|
||||||
] # stochastic depth decay rule
|
x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
|
||||||
self.blocks = nn.ModuleList([
|
] # stochastic depth decay rule
|
||||||
Block(
|
|
||||||
dim=embed_dim,
|
|
||||||
num_heads=num_heads,
|
|
||||||
mlp_ratio=mlp_ratio,
|
|
||||||
qkv_bias=qkv_bias,
|
|
||||||
qk_scale=qk_scale,
|
|
||||||
drop=dpr[i],
|
|
||||||
attn_drop=attn_drop_rate,
|
|
||||||
act_cfg=act_cfg,
|
|
||||||
norm_cfg=norm_cfg,
|
|
||||||
with_cp=with_cp) for i in range(depth)
|
|
||||||
])
|
|
||||||
|
|
||||||
assert out_shape in ['NLC',
|
self.layers = ModuleList()
|
||||||
'NCHW'], 'output shape must be "NLC" or "NCHW".'
|
for i in range(num_layers):
|
||||||
|
self.layers.append(
|
||||||
self.out_shape = out_shape
|
TransformerEncoderLayer(
|
||||||
|
embed_dims=embed_dims,
|
||||||
|
num_heads=num_heads,
|
||||||
|
feedforward_channels=mlp_ratio * embed_dims,
|
||||||
|
attn_drop_rate=attn_drop_rate,
|
||||||
|
drop_rate=drop_rate,
|
||||||
|
drop_path_rate=dpr[i],
|
||||||
|
num_fcs=num_fcs,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
act_cfg=act_cfg,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
batch_first=True))
|
||||||
|
|
||||||
self.interpolate_mode = interpolate_mode
|
self.interpolate_mode = interpolate_mode
|
||||||
self.final_norm = final_norm
|
self.final_norm = final_norm
|
||||||
if final_norm:
|
if final_norm:
|
||||||
_, self.norm = build_norm_layer(norm_cfg, embed_dim)
|
self.norm1_name, norm1 = build_norm_layer(
|
||||||
|
norm_cfg, embed_dims, postfix=1)
|
||||||
|
self.add_module(self.norm1_name, norm1)
|
||||||
|
|
||||||
self.norm_eval = norm_eval
|
self.norm_eval = norm_eval
|
||||||
self.with_cp = with_cp
|
self.with_cp = with_cp
|
||||||
|
|
||||||
def init_weights(self):
|
@property
|
||||||
pretrained = self.pretrained
|
def norm1(self):
|
||||||
|
return getattr(self, self.norm1_name)
|
||||||
|
|
||||||
|
def init_weights(self, pretrained=None):
|
||||||
if isinstance(pretrained, str):
|
if isinstance(pretrained, str):
|
||||||
logger = get_root_logger()
|
logger = get_root_logger()
|
||||||
checkpoint = _load_checkpoint(pretrained, logger=logger)
|
checkpoint = _load_checkpoint(pretrained, logger=logger)
|
||||||
@ -338,10 +305,17 @@ class VisionTransformer(BaseModule):
|
|||||||
else:
|
else:
|
||||||
state_dict = checkpoint
|
state_dict = checkpoint
|
||||||
|
|
||||||
|
if self.pretrain_style == 'timm':
|
||||||
|
# Because the refactor of vit is blocked by mmcls,
|
||||||
|
# so we firstly use timm pretrain weights to train
|
||||||
|
# downstream model.
|
||||||
|
state_dict = vit_convert(state_dict)
|
||||||
|
|
||||||
if 'pos_embed' in state_dict.keys():
|
if 'pos_embed' in state_dict.keys():
|
||||||
if self.pos_embed.shape != state_dict['pos_embed'].shape:
|
if self.pos_embed.shape != state_dict['pos_embed'].shape:
|
||||||
logger.info(msg=f'Resize the pos_embed shape from \
|
logger.info(msg=f'Resize the pos_embed shape from '
|
||||||
{state_dict["pos_embed"].shape} to {self.pos_embed.shape}')
|
f'{state_dict["pos_embed"].shape} to '
|
||||||
|
f'{self.pos_embed.shape}')
|
||||||
h, w = self.img_size
|
h, w = self.img_size
|
||||||
pos_size = int(
|
pos_size = int(
|
||||||
math.sqrt(state_dict['pos_embed'].shape[1] - 1))
|
math.sqrt(state_dict['pos_embed'].shape[1] - 1))
|
||||||
@ -354,17 +328,17 @@ class VisionTransformer(BaseModule):
|
|||||||
elif pretrained is None:
|
elif pretrained is None:
|
||||||
# We only implement the 'jax_impl' initialization implemented at
|
# We only implement the 'jax_impl' initialization implemented at
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
||||||
trunc_normal_(self.pos_embed, std=.02)
|
trunc_normal_init(self.pos_embed, std=.02)
|
||||||
trunc_normal_(self.cls_token, std=.02)
|
trunc_normal_init(self.cls_token, std=.02)
|
||||||
for n, m in self.named_modules():
|
for n, m in self.named_modules():
|
||||||
if isinstance(m, Linear):
|
if isinstance(m, nn.Linear):
|
||||||
trunc_normal_(m.weight, std=.02)
|
trunc_normal_init(m.weight, std=.02)
|
||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
if 'mlp' in n:
|
if 'ffn' in n:
|
||||||
normal_init(m.bias, std=1e-6)
|
normal_init(m.bias, std=1e-6)
|
||||||
else:
|
else:
|
||||||
constant_init(m.bias, 0)
|
constant_init(m.bias, 0)
|
||||||
elif isinstance(m, Conv2d):
|
elif isinstance(m, nn.Conv2d):
|
||||||
kaiming_init(m.weight, mode='fan_in')
|
kaiming_init(m.weight, mode='fan_in')
|
||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
constant_init(m.bias, 0)
|
constant_init(m.bias, 0)
|
||||||
@ -404,7 +378,7 @@ class VisionTransformer(BaseModule):
|
|||||||
pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
|
pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
|
||||||
(pos_h, pos_w), self.patch_size,
|
(pos_h, pos_w), self.patch_size,
|
||||||
self.interpolate_mode)
|
self.interpolate_mode)
|
||||||
return self.pos_drop(patched_img + pos_embed)
|
return self.drop_after_pos(patched_img + pos_embed)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
|
def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
|
||||||
@ -441,31 +415,31 @@ class VisionTransformer(BaseModule):
|
|||||||
|
|
||||||
x = self.patch_embed(inputs)
|
x = self.patch_embed(inputs)
|
||||||
|
|
||||||
|
# stole cls_tokens impl from Phil Wang, thanks
|
||||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||||
x = torch.cat((cls_tokens, x), dim=1)
|
x = torch.cat((cls_tokens, x), dim=1)
|
||||||
x = self._pos_embeding(inputs, x, self.pos_embed)
|
x = self._pos_embeding(inputs, x, self.pos_embed)
|
||||||
|
|
||||||
if not self.with_cls_token:
|
if not self.with_cls_token:
|
||||||
# Remove class token for transformer input
|
# Remove class token for transformer encoder input
|
||||||
x = x[:, 1:]
|
x = x[:, 1:]
|
||||||
|
|
||||||
outs = []
|
outs = []
|
||||||
for i, blk in enumerate(self.blocks):
|
for i, layer in enumerate(self.layers):
|
||||||
x = blk(x)
|
x = layer(x)
|
||||||
if i == len(self.blocks) - 1:
|
if i == len(self.layers) - 1:
|
||||||
if self.final_norm:
|
if self.final_norm:
|
||||||
x = self.norm(x)
|
x = self.norm1(x)
|
||||||
if i in self.out_indices:
|
if i in self.out_indices:
|
||||||
if self.with_cls_token:
|
if self.with_cls_token:
|
||||||
# Remove class token and reshape token for decoder head
|
# Remove class token and reshape token for decoder head
|
||||||
out = x[:, 1:]
|
out = x[:, 1:]
|
||||||
else:
|
else:
|
||||||
out = x
|
out = x
|
||||||
if self.out_shape == 'NCHW':
|
B, _, C = out.shape
|
||||||
B, _, C = out.shape
|
out = out.reshape(B, inputs.shape[2] // self.patch_size,
|
||||||
out = out.reshape(B, inputs.shape[2] // self.patch_size,
|
inputs.shape[3] // self.patch_size,
|
||||||
inputs.shape[3] // self.patch_size,
|
C).permute(0, 3, 1, 2)
|
||||||
C).permute(0, 3, 1, 2)
|
|
||||||
outs.append(out)
|
outs.append(out)
|
||||||
|
|
||||||
return tuple(outs)
|
return tuple(outs)
|
||||||
|
@ -4,10 +4,10 @@ from .make_divisible import make_divisible
|
|||||||
from .res_layer import ResLayer
|
from .res_layer import ResLayer
|
||||||
from .se_layer import SELayer
|
from .se_layer import SELayer
|
||||||
from .self_attention_block import SelfAttentionBlock
|
from .self_attention_block import SelfAttentionBlock
|
||||||
|
from .timm_convert import vit_convert
|
||||||
from .up_conv_block import UpConvBlock
|
from .up_conv_block import UpConvBlock
|
||||||
from .weight_init import trunc_normal_
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
|
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
|
||||||
'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'DropPath', 'trunc_normal_'
|
'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'DropPath', 'vit_convert'
|
||||||
]
|
]
|
||||||
|
33
mmseg/models/utils/timm_convert.py
Normal file
33
mmseg/models/utils/timm_convert.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
|
def vit_convert(timm_dict):
|
||||||
|
|
||||||
|
mmseg_dict = OrderedDict()
|
||||||
|
|
||||||
|
for k, v in timm_dict.items():
|
||||||
|
if k.startswith('head'):
|
||||||
|
continue
|
||||||
|
if k.startswith('norm'):
|
||||||
|
new_k = k.replace('norm.', 'ln1.')
|
||||||
|
elif k.startswith('patch_embed'):
|
||||||
|
if 'proj' in k:
|
||||||
|
new_k = k.replace('proj', 'projection')
|
||||||
|
elif k.startswith('blocks'):
|
||||||
|
new_k = k.replace('blocks.', 'layers.')
|
||||||
|
if 'norm' in new_k:
|
||||||
|
new_k = new_k.replace('norm', 'ln')
|
||||||
|
elif 'mlp.fc1' in new_k:
|
||||||
|
new_k = new_k.replace('mlp.fc1', 'ffn.layers.0.0')
|
||||||
|
elif 'mlp.fc2' in new_k:
|
||||||
|
new_k = new_k.replace('mlp.fc2', 'ffn.layers.1')
|
||||||
|
elif 'attn.qkv' in new_k:
|
||||||
|
new_k = new_k.replace('attn.qkv.', 'attn.attn.in_proj_')
|
||||||
|
elif 'attn.proj' in new_k:
|
||||||
|
new_k = new_k.replace('attn.proj', 'attn.attn.out_proj')
|
||||||
|
else:
|
||||||
|
new_k = k
|
||||||
|
new_k = f'backbone.{new_k}'
|
||||||
|
mmseg_dict[new_k] = v
|
||||||
|
|
||||||
|
return mmseg_dict
|
@ -1,62 +0,0 @@
|
|||||||
"""Modified from https://github.com/rwightman/pytorch-image-
|
|
||||||
models/blob/master/timm/models/layers/drop.py."""
|
|
||||||
|
|
||||||
import math
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
|
||||||
"""Reference: https://people.sc.fsu.edu/~jburkardt/presentations
|
|
||||||
/truncated_normal.pdf"""
|
|
||||||
|
|
||||||
def norm_cdf(x):
|
|
||||||
# Computes standard normal cumulative distribution function
|
|
||||||
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
|
||||||
|
|
||||||
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
|
||||||
warnings.warn(
|
|
||||||
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
|
|
||||||
'The distribution of values may be incorrect.',
|
|
||||||
stacklevel=2)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
# Values are generated by using a truncated uniform distribution and
|
|
||||||
# then using the inverse CDF for the normal distribution.
|
|
||||||
# Get upper and lower cdf values
|
|
||||||
lower_bound = norm_cdf((a - mean) / std)
|
|
||||||
upper_bound = norm_cdf((b - mean) / std)
|
|
||||||
|
|
||||||
# Uniformly fill tensor with values from [l, u], then translate to
|
|
||||||
# [2l-1, 2u-1].
|
|
||||||
tensor.uniform_(2 * lower_bound - 1, 2 * upper_bound - 1)
|
|
||||||
|
|
||||||
# Use inverse cdf transform for normal distribution to get truncated
|
|
||||||
# standard normal
|
|
||||||
tensor.erfinv_()
|
|
||||||
|
|
||||||
# Transform to proper mean, std
|
|
||||||
tensor.mul_(std * math.sqrt(2.))
|
|
||||||
tensor.add_(mean)
|
|
||||||
|
|
||||||
# Clamp to ensure it's in the proper range
|
|
||||||
tensor.clamp_(min=a, max=b)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
|
|
||||||
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
|
||||||
r"""Fills the input Tensor with values drawn from a truncated
|
|
||||||
normal distribution. The values are effectively drawn from the
|
|
||||||
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
|
||||||
with values outside :math:`[a, b]` redrawn until they are within
|
|
||||||
the bounds. The method used for generating the random values works
|
|
||||||
best when :math:`a \leq \text{mean} \leq b`.
|
|
||||||
Args:
|
|
||||||
tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`
|
|
||||||
mean (float): the mean of the normal distribution
|
|
||||||
std (float): the standard deviation of the normal distribution
|
|
||||||
a (float): the minimum cutoff value
|
|
||||||
b (float): the maximum cutoff value
|
|
||||||
"""
|
|
||||||
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
|
@ -24,19 +24,18 @@ def test_vit_backbone():
|
|||||||
x = torch.randn(1, 196)
|
x = torch.randn(1, 196)
|
||||||
VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear')
|
VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear')
|
||||||
|
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(ValueError):
|
||||||
# forward inputs must be [N, C, H, W]
|
# forward inputs must be [N, C, H, W]
|
||||||
x = torch.randn(3, 30, 30)
|
x = torch.randn(3, 30, 30)
|
||||||
model = VisionTransformer()
|
model = VisionTransformer()
|
||||||
model(x)
|
model(x)
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
# out_shape must be 'NLC' or 'NCHW;'
|
VisionTransformer(img_size=(224, 224, 224))
|
||||||
VisionTransformer(out_shape='NCL')
|
|
||||||
|
|
||||||
# Test img_size isinstance int
|
# Test img_size isinstance tuple
|
||||||
imgs = torch.randn(1, 3, 224, 224)
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
model = VisionTransformer(img_size=224)
|
model = VisionTransformer(img_size=(224, 224))
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model(imgs)
|
model(imgs)
|
||||||
|
|
||||||
@ -65,6 +64,11 @@ def test_vit_backbone():
|
|||||||
feat = model(imgs)
|
feat = model(imgs)
|
||||||
assert feat[-1].shape == (1, 768, 14, 14)
|
assert feat[-1].shape == (1, 768, 14, 14)
|
||||||
|
|
||||||
|
# Test unbalanced size input image
|
||||||
|
imgs = torch.randn(1, 3, 112, 224)
|
||||||
|
feat = model(imgs)
|
||||||
|
assert feat[-1].shape == (1, 768, 7, 14)
|
||||||
|
|
||||||
# Test with_cp=True
|
# Test with_cp=True
|
||||||
model = VisionTransformer(with_cp=True)
|
model = VisionTransformer(with_cp=True)
|
||||||
imgs = torch.randn(1, 3, 224, 224)
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
@ -77,8 +81,8 @@ def test_vit_backbone():
|
|||||||
feat = model(imgs)
|
feat = model(imgs)
|
||||||
assert feat[-1].shape == (1, 768, 14, 14)
|
assert feat[-1].shape == (1, 768, 14, 14)
|
||||||
|
|
||||||
# Test final reshape arg
|
# Test final norm
|
||||||
|
model = VisionTransformer(final_norm=True)
|
||||||
imgs = torch.randn(1, 3, 224, 224)
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
model = VisionTransformer(out_shape='NLC')
|
|
||||||
feat = model(imgs)
|
feat = model(imgs)
|
||||||
assert feat[-1].shape == (1, 196, 768)
|
assert feat[-1].shape == (1, 768, 14, 14)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user