diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py index bda2a3545..1d730d863 100644 --- a/mmseg/models/backbones/vit.py +++ b/mmseg/models/backbones/vit.py @@ -8,12 +8,13 @@ import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as cp from mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer, - constant_init, kaiming_init, normal_init, xavier_init) + constant_init, kaiming_init, normal_init) from mmcv.runner import _load_checkpoint from mmcv.utils.parrots_wrapper import _BatchNorm from mmseg.utils import get_root_logger from ..builder import BACKBONES +from ..utils import DropPath, trunc_normal_ class Mlp(nn.Module): @@ -114,10 +115,14 @@ class Block(nn.Module): Default: 0. proj_drop (float): Drop rate for attn layer output weights. Default: 0. + drop_path (float): Drop rate for paths of model. + Default: 0. act_cfg (dict): Config dict for activation layer. Default: dict(type='GELU'). norm_cfg (dict): Config dict for normalization layer. Default: dict(type='LN', requires_grad=True). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. """ def __init__(self, @@ -129,14 +134,17 @@ class Block(nn.Module): drop=0., attn_drop=0., proj_drop=0., + drop_path=0., act_cfg=dict(type='GELU'), - norm_cfg=dict(type='LN'), + norm_cfg=dict(type='LN', eps=1e-6), with_cp=False): super(Block, self).__init__() self.with_cp = with_cp _, self.norm1 = build_norm_layer(norm_cfg, dim) self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop) + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() _, self.norm2 = build_norm_layer(norm_cfg, dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( @@ -148,8 +156,8 @@ class Block(nn.Module): def forward(self, x): def _inner_forward(x): - out = x + self.attn(self.norm1(x)) - out = out + self.mlp(self.norm2(out)) + out = x + self.drop_path(self.attn(self.norm1(x))) + out = out + self.drop_path(self.mlp(self.norm2(out))) return out if self.with_cp and x.requires_grad: @@ -164,7 +172,7 @@ class PatchEmbed(nn.Module): """Image to Patch Embedding. Args: - img_size (int, tuple): Input image size. + img_size (int | tuple): Input image size. default: 224. patch_size (int): Width and height for a patch. default: 16. @@ -202,24 +210,34 @@ class VisionTransformer(nn.Module): Image Recognition at Scale` - https://arxiv.org/abs/2010.11929 Args: - img_size (tuple): input image size. Default: (224, 224). + img_size (tuple): input image size. Default: (224, 224). patch_size (int, tuple): patch size. Default: 16. in_channels (int): number of input channels. Default: 3. embed_dim (int): embedding dimension. Default: 768. depth (int): depth of transformer. Default: 12. num_heads (int): number of attention heads. Default: 12. - mlp_ratio (int): ratio of mlp hidden dim to embedding dim. Default: 4. + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + out_indices (list | tuple | int): Output from which stages. + Default: -1. qkv_bias (bool): enable bias for qkv if True. Default: True. qk_scale (float): override default qk scale of head_dim ** -0.5 if set. drop_rate (float): dropout rate. Default: 0. attn_drop_rate (float): attention dropout rate. Default: 0. + drop_path_rate (float): Rate of DropPath. Default: 0. norm_cfg (dict): Config dict for normalization layer. - Default: dict(type='LN', requires_grad=True). + Default: dict(type='LN', eps=1e-6, requires_grad=True). act_cfg (dict): Config dict for activation layer. Default: dict(type='GELU'). norm_eval (bool): Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effect on Batch Norm and its variants only. Default: False. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Default: False. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Default: bicubic. + with_cls_token (bool): If concatenating class token into image tokens + as transformer input. Default: True. with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. @@ -233,13 +251,18 @@ class VisionTransformer(nn.Module): depth=12, num_heads=12, mlp_ratio=4, + out_indices=11, qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., - norm_cfg=dict(type='LN'), + drop_path_rate=0., + norm_cfg=dict(type='LN', eps=1e-6, requires_grad=True), act_cfg=dict(type='GELU'), norm_eval=False, + final_norm=False, + with_cls_token=True, + interpolate_mode='bicubic', with_cp=False): super(VisionTransformer, self).__init__() self.img_size = img_size @@ -251,24 +274,39 @@ class VisionTransformer(nn.Module): in_channels=in_channels, embed_dim=embed_dim) + self.with_cls_token = with_cls_token + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) self.pos_embed = nn.Parameter( - torch.zeros(1, self.patch_embed.num_patches, embed_dim)) + torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)) self.pos_drop = nn.Dropout(p=drop_rate) - self.blocks = nn.Sequential(*[ + if isinstance(out_indices, int): + self.out_indices = [out_indices] + elif isinstance(out_indices, list) or isinstance(out_indices, tuple): + self.out_indices = out_indices + else: + raise TypeError('out_indices must be type of int, list or tuple') + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.blocks = nn.ModuleList([ Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, + drop=dpr[i], attn_drop=attn_drop_rate, act_cfg=act_cfg, norm_cfg=norm_cfg, with_cp=with_cp) for i in range(depth) ]) - _, self.norm = build_norm_layer(norm_cfg, embed_dim) + + self.interpolate_mode = interpolate_mode + self.final_norm = final_norm + if final_norm: + _, self.norm = build_norm_layer(norm_cfg, embed_dim) self.norm_eval = norm_eval self.with_cp = with_cp @@ -283,28 +321,26 @@ class VisionTransformer(nn.Module): state_dict = checkpoint if 'pos_embed' in state_dict.keys(): - state_dict['pos_embed'] = state_dict['pos_embed'][:, 1:, :] - logger.info( - msg='Remove the "cls_token" dimension from the checkpoint') - if self.pos_embed.shape != state_dict['pos_embed'].shape: logger.info(msg=f'Resize the pos_embed shape from \ - {state_dict["pos_embed"].shape} to \ - {self.pos_embed.shape}') +{state_dict["pos_embed"].shape} to {self.pos_embed.shape}') h, w = self.img_size - pos_size = int(math.sqrt(state_dict['pos_embed'].shape[1])) + pos_size = int( + math.sqrt(state_dict['pos_embed'].shape[1] - 1)) state_dict['pos_embed'] = self.resize_pos_embed( state_dict['pos_embed'], (h, w), (pos_size, pos_size), - self.patch_size) + self.patch_size, self.interpolate_mode) + self.load_state_dict(state_dict, False) elif pretrained is None: # We only implement the 'jax_impl' initialization implemented at # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 - normal_init(self.pos_embed) + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) for n, m in self.named_modules(): if isinstance(m, Linear): - xavier_init(m.weight, distribution='uniform') + trunc_normal_(m.weight, std=.02) if m.bias is not None: if 'mlp' in n: normal_init(m.bias, std=1e-6) @@ -316,7 +352,7 @@ class VisionTransformer(nn.Module): constant_init(m.bias, 0) elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): constant_init(m.bias, 0) - constant_init(m.weight, 1) + constant_init(m.weight, 1.0) else: raise TypeError('pretrained must be a str or None') @@ -340,7 +376,7 @@ class VisionTransformer(nn.Module): x_len, pos_len = patched_img.shape[1], pos_embed.shape[1] if x_len != pos_len: if pos_len == (self.img_size[0] // self.patch_size) * ( - self.img_size[1] // self.patch_size): + self.img_size[1] // self.patch_size) + 1: pos_h = self.img_size[0] // self.patch_size pos_w = self.img_size[1] // self.patch_size else: @@ -348,11 +384,12 @@ class VisionTransformer(nn.Module): 'Unexpected shape of pos_embed, got {}.'.format( pos_embed.shape)) pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:], - (pos_h, pos_w), self.patch_size) - return patched_img + pos_embed + (pos_h, pos_w), self.patch_size, + self.interpolate_mode) + return self.pos_drop(patched_img + pos_embed) @staticmethod - def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size): + def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode): """Resize pos_embed weights. Resize pos_embed using bicubic interpolate method. @@ -367,26 +404,52 @@ class VisionTransformer(nn.Module): assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]' input_h, input_w = input_shpae pos_h, pos_w = pos_shape - pos_embed = pos_embed.reshape(1, pos_h, pos_w, - pos_embed.shape[2]).permute(0, 3, 1, 2) - pos_embed = F.interpolate( - pos_embed, + cls_token_weight = pos_embed[:, 0] + pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):] + pos_embed_weight = pos_embed_weight.reshape( + 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2) + pos_embed_weight = F.interpolate( + pos_embed_weight, size=[input_h // patch_size, input_w // patch_size], align_corners=False, - mode='bicubic') - pos_embed = torch.flatten(pos_embed, 2).transpose(1, 2) + mode=mode) + cls_token_weight = cls_token_weight.unsqueeze(1) + pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2) + pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1) return pos_embed def forward(self, inputs): + B = inputs.shape[0] + x = self.patch_embed(inputs) + + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) x = self._pos_embeding(inputs, x, self.pos_embed) - x = self.blocks(x) - x = self.norm(x) - B, _, C = x.shape - x = x.reshape(B, inputs.shape[2] // self.patch_size, - inputs.shape[3] // self.patch_size, - C).permute(0, 3, 1, 2) - return [x] + + if not self.with_cls_token: + # Remove class token for transformer input + x = x[:, 1:] + + outs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if i == len(self.blocks) - 1: + if self.final_norm: + x = self.norm(x) + if i in self.out_indices: + if self.with_cls_token: + # Remove class token and reshape token for decoder head + out = x[:, 1:] + else: + out = x + B, _, C = out.shape + out = out.reshape(B, inputs.shape[2] // self.patch_size, + inputs.shape[3] // self.patch_size, + C).permute(0, 3, 1, 2) + outs.append(out) + + return tuple(outs) def train(self, mode=True): super(VisionTransformer, self).train(mode) diff --git a/mmseg/models/utils/__init__.py b/mmseg/models/utils/__init__.py index 8f0fc16ff..3d3bdd349 100644 --- a/mmseg/models/utils/__init__.py +++ b/mmseg/models/utils/__init__.py @@ -1,11 +1,13 @@ +from .drop import DropPath from .inverted_residual import InvertedResidual, InvertedResidualV3 from .make_divisible import make_divisible from .res_layer import ResLayer from .se_layer import SELayer from .self_attention_block import SelfAttentionBlock from .up_conv_block import UpConvBlock +from .weight_init import trunc_normal_ __all__ = [ 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', - 'UpConvBlock', 'InvertedResidualV3', 'SELayer' + 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'DropPath', 'trunc_normal_' ] diff --git a/mmseg/models/utils/drop.py b/mmseg/models/utils/drop.py new file mode 100644 index 000000000..4520b0ff4 --- /dev/null +++ b/mmseg/models/utils/drop.py @@ -0,0 +1,31 @@ +"""Modified from https://github.com/rwightman/pytorch-image- +models/blob/master/timm/models/layers/drop.py.""" + +import torch +from torch import nn + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of + residual blocks). + + Args: + drop_prob (float): Drop rate for paths of model. Dropout rate has + to be between 0 and 1. Default: 0. + """ + + def __init__(self, drop_prob=0.): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.keep_prob = 1 - drop_prob + + def forward(self, x): + if self.drop_prob == 0. or not self.training: + return x + shape = (x.shape[0], ) + (1, ) * ( + x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = self.keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(self.keep_prob) * random_tensor + return output diff --git a/mmseg/models/utils/weight_init.py b/mmseg/models/utils/weight_init.py new file mode 100644 index 000000000..38141ba3d --- /dev/null +++ b/mmseg/models/utils/weight_init.py @@ -0,0 +1,62 @@ +"""Modified from https://github.com/rwightman/pytorch-image- +models/blob/master/timm/models/layers/drop.py.""" + +import math +import warnings + +import torch + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + """Reference: https://people.sc.fsu.edu/~jburkardt/presentations + /truncated_normal.pdf""" + + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ' + 'The distribution of values may be incorrect.', + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + lower_bound = norm_cdf((a - mean) / std) + upper_bound = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * lower_bound - 1, 2 * upper_bound - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor (``torch.Tensor``): an n-dimensional `torch.Tensor` + mean (float): the mean of the normal distribution + std (float): the standard deviation of the normal distribution + a (float): the minimum cutoff value + b (float): the maximum cutoff value + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/tests/test_models/test_backbones/test_vit.py b/tests/test_models/test_backbones/test_vit.py index 5c5572e43..c36894ec9 100644 --- a/tests/test_models/test_backbones/test_vit.py +++ b/tests/test_models/test_backbones/test_vit.py @@ -15,10 +15,14 @@ def test_vit_backbone(): # img_size must be int or tuple model = VisionTransformer(img_size=512.0) + with pytest.raises(TypeError): + # out_indices must be int ,list or tuple + model = VisionTransformer(out_indices=1.) + with pytest.raises(TypeError): # test upsample_pos_embed function x = torch.randn(1, 196) - VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224) + VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear') with pytest.raises(RuntimeError): # forward inputs must be [N, C, H, W] @@ -46,19 +50,25 @@ def test_vit_backbone(): # Test large size input image imgs = torch.randn(1, 3, 256, 256) feat = model(imgs) - assert feat[0].shape == (1, 768, 16, 16) + assert feat[-1].shape == (1, 768, 16, 16) # Test small size input image imgs = torch.randn(1, 3, 32, 32) feat = model(imgs) - assert feat[0].shape == (1, 768, 2, 2) + assert feat[-1].shape == (1, 768, 2, 2) imgs = torch.randn(1, 3, 224, 224) feat = model(imgs) - assert feat[0].shape == (1, 768, 14, 14) + assert feat[-1].shape == (1, 768, 14, 14) # Test with_cp=True model = VisionTransformer(with_cp=True) imgs = torch.randn(1, 3, 224, 224) feat = model(imgs) - assert feat[0].shape == (1, 768, 14, 14) + assert feat[-1].shape == (1, 768, 14, 14) + + # Test with_cls_token=False + model = VisionTransformer(with_cls_token=False) + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert feat[-1].shape == (1, 768, 14, 14) diff --git a/tests/test_models/test_utils/test_drop.py b/tests/test_models/test_utils/test_drop.py new file mode 100644 index 000000000..1331af8d0 --- /dev/null +++ b/tests/test_models/test_utils/test_drop.py @@ -0,0 +1,28 @@ +import torch + +from mmseg.models.utils import DropPath + + +def test_drop_path(): + + # zero drop + layer = DropPath() + + # input NLC format feature + x = torch.randn((1, 16, 32)) + layer(x) + + # input NLHW format feature + x = torch.randn((1, 32, 4, 4)) + layer(x) + + # non-zero drop + layer = DropPath(0.1) + + # input NLC format feature + x = torch.randn((1, 16, 32)) + layer(x) + + # input NLHW format feature + x = torch.randn((1, 32, 4, 4)) + layer(x)