mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Adjust vision transformer backbone architectures (#524)
* Adjust vision transformer backbone architectures; * Add DropPath, trunc_normal_ for VisionTransformer implementation; * Add class token buring intermediate period and remove it during final period; * Fix some parameters loss bug; * * Store intermediate token features and impose no processes on them; * Remove class token and reshape entire token feature from NLC to NCHW; * Fix some doc error * Add a arg for VisionTransformer backbone to control if input class token into transformer; * Add stochastic depth decay rule for DropPath; * * Fix output bug when input_cls_token=False; * Add related unit test; * * Add arg: out_indices to control model output; * Add unit test for DropPath; * Apply suggestions from code review Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>
This commit is contained in:
parent
771ca7d3e0
commit
c27ef91942
@ -8,12 +8,13 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint as cp
|
import torch.utils.checkpoint as cp
|
||||||
from mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer,
|
from mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer,
|
||||||
constant_init, kaiming_init, normal_init, xavier_init)
|
constant_init, kaiming_init, normal_init)
|
||||||
from mmcv.runner import _load_checkpoint
|
from mmcv.runner import _load_checkpoint
|
||||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||||
|
|
||||||
from mmseg.utils import get_root_logger
|
from mmseg.utils import get_root_logger
|
||||||
from ..builder import BACKBONES
|
from ..builder import BACKBONES
|
||||||
|
from ..utils import DropPath, trunc_normal_
|
||||||
|
|
||||||
|
|
||||||
class Mlp(nn.Module):
|
class Mlp(nn.Module):
|
||||||
@ -114,10 +115,14 @@ class Block(nn.Module):
|
|||||||
Default: 0.
|
Default: 0.
|
||||||
proj_drop (float): Drop rate for attn layer output weights.
|
proj_drop (float): Drop rate for attn layer output weights.
|
||||||
Default: 0.
|
Default: 0.
|
||||||
|
drop_path (float): Drop rate for paths of model.
|
||||||
|
Default: 0.
|
||||||
act_cfg (dict): Config dict for activation layer.
|
act_cfg (dict): Config dict for activation layer.
|
||||||
Default: dict(type='GELU').
|
Default: dict(type='GELU').
|
||||||
norm_cfg (dict): Config dict for normalization layer.
|
norm_cfg (dict): Config dict for normalization layer.
|
||||||
Default: dict(type='LN', requires_grad=True).
|
Default: dict(type='LN', requires_grad=True).
|
||||||
|
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||||
|
memory while slowing down the training speed. Default: False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -129,14 +134,17 @@ class Block(nn.Module):
|
|||||||
drop=0.,
|
drop=0.,
|
||||||
attn_drop=0.,
|
attn_drop=0.,
|
||||||
proj_drop=0.,
|
proj_drop=0.,
|
||||||
|
drop_path=0.,
|
||||||
act_cfg=dict(type='GELU'),
|
act_cfg=dict(type='GELU'),
|
||||||
norm_cfg=dict(type='LN'),
|
norm_cfg=dict(type='LN', eps=1e-6),
|
||||||
with_cp=False):
|
with_cp=False):
|
||||||
super(Block, self).__init__()
|
super(Block, self).__init__()
|
||||||
self.with_cp = with_cp
|
self.with_cp = with_cp
|
||||||
_, self.norm1 = build_norm_layer(norm_cfg, dim)
|
_, self.norm1 = build_norm_layer(norm_cfg, dim)
|
||||||
self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop,
|
self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop,
|
||||||
proj_drop)
|
proj_drop)
|
||||||
|
self.drop_path = DropPath(
|
||||||
|
drop_path) if drop_path > 0. else nn.Identity()
|
||||||
_, self.norm2 = build_norm_layer(norm_cfg, dim)
|
_, self.norm2 = build_norm_layer(norm_cfg, dim)
|
||||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
self.mlp = Mlp(
|
self.mlp = Mlp(
|
||||||
@ -148,8 +156,8 @@ class Block(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
||||||
def _inner_forward(x):
|
def _inner_forward(x):
|
||||||
out = x + self.attn(self.norm1(x))
|
out = x + self.drop_path(self.attn(self.norm1(x)))
|
||||||
out = out + self.mlp(self.norm2(out))
|
out = out + self.drop_path(self.mlp(self.norm2(out)))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
if self.with_cp and x.requires_grad:
|
if self.with_cp and x.requires_grad:
|
||||||
@ -164,7 +172,7 @@ class PatchEmbed(nn.Module):
|
|||||||
"""Image to Patch Embedding.
|
"""Image to Patch Embedding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
img_size (int, tuple): Input image size.
|
img_size (int | tuple): Input image size.
|
||||||
default: 224.
|
default: 224.
|
||||||
patch_size (int): Width and height for a patch.
|
patch_size (int): Width and height for a patch.
|
||||||
default: 16.
|
default: 16.
|
||||||
@ -202,24 +210,34 @@ class VisionTransformer(nn.Module):
|
|||||||
Image Recognition at Scale` - https://arxiv.org/abs/2010.11929
|
Image Recognition at Scale` - https://arxiv.org/abs/2010.11929
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
img_size (tuple): input image size. Default: (224, 224).
|
img_size (tuple): input image size. Default: (224, 224).
|
||||||
patch_size (int, tuple): patch size. Default: 16.
|
patch_size (int, tuple): patch size. Default: 16.
|
||||||
in_channels (int): number of input channels. Default: 3.
|
in_channels (int): number of input channels. Default: 3.
|
||||||
embed_dim (int): embedding dimension. Default: 768.
|
embed_dim (int): embedding dimension. Default: 768.
|
||||||
depth (int): depth of transformer. Default: 12.
|
depth (int): depth of transformer. Default: 12.
|
||||||
num_heads (int): number of attention heads. Default: 12.
|
num_heads (int): number of attention heads. Default: 12.
|
||||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim. Default: 4.
|
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
|
||||||
|
Default: 4.
|
||||||
|
out_indices (list | tuple | int): Output from which stages.
|
||||||
|
Default: -1.
|
||||||
qkv_bias (bool): enable bias for qkv if True. Default: True.
|
qkv_bias (bool): enable bias for qkv if True. Default: True.
|
||||||
qk_scale (float): override default qk scale of head_dim ** -0.5 if set.
|
qk_scale (float): override default qk scale of head_dim ** -0.5 if set.
|
||||||
drop_rate (float): dropout rate. Default: 0.
|
drop_rate (float): dropout rate. Default: 0.
|
||||||
attn_drop_rate (float): attention dropout rate. Default: 0.
|
attn_drop_rate (float): attention dropout rate. Default: 0.
|
||||||
|
drop_path_rate (float): Rate of DropPath. Default: 0.
|
||||||
norm_cfg (dict): Config dict for normalization layer.
|
norm_cfg (dict): Config dict for normalization layer.
|
||||||
Default: dict(type='LN', requires_grad=True).
|
Default: dict(type='LN', eps=1e-6, requires_grad=True).
|
||||||
act_cfg (dict): Config dict for activation layer.
|
act_cfg (dict): Config dict for activation layer.
|
||||||
Default: dict(type='GELU').
|
Default: dict(type='GELU').
|
||||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||||
and its variants only. Default: False.
|
and its variants only. Default: False.
|
||||||
|
final_norm (bool): Whether to add a additional layer to normalize
|
||||||
|
final feature map. Default: False.
|
||||||
|
interpolate_mode (str): Select the interpolate mode for position
|
||||||
|
embeding vector resize. Default: bicubic.
|
||||||
|
with_cls_token (bool): If concatenating class token into image tokens
|
||||||
|
as transformer input. Default: True.
|
||||||
with_cp (bool): Use checkpoint or not. Using checkpoint
|
with_cp (bool): Use checkpoint or not. Using checkpoint
|
||||||
will save some memory while slowing down the training speed.
|
will save some memory while slowing down the training speed.
|
||||||
Default: False.
|
Default: False.
|
||||||
@ -233,13 +251,18 @@ class VisionTransformer(nn.Module):
|
|||||||
depth=12,
|
depth=12,
|
||||||
num_heads=12,
|
num_heads=12,
|
||||||
mlp_ratio=4,
|
mlp_ratio=4,
|
||||||
|
out_indices=11,
|
||||||
qkv_bias=True,
|
qkv_bias=True,
|
||||||
qk_scale=None,
|
qk_scale=None,
|
||||||
drop_rate=0.,
|
drop_rate=0.,
|
||||||
attn_drop_rate=0.,
|
attn_drop_rate=0.,
|
||||||
norm_cfg=dict(type='LN'),
|
drop_path_rate=0.,
|
||||||
|
norm_cfg=dict(type='LN', eps=1e-6, requires_grad=True),
|
||||||
act_cfg=dict(type='GELU'),
|
act_cfg=dict(type='GELU'),
|
||||||
norm_eval=False,
|
norm_eval=False,
|
||||||
|
final_norm=False,
|
||||||
|
with_cls_token=True,
|
||||||
|
interpolate_mode='bicubic',
|
||||||
with_cp=False):
|
with_cp=False):
|
||||||
super(VisionTransformer, self).__init__()
|
super(VisionTransformer, self).__init__()
|
||||||
self.img_size = img_size
|
self.img_size = img_size
|
||||||
@ -251,23 +274,38 @@ class VisionTransformer(nn.Module):
|
|||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
embed_dim=embed_dim)
|
embed_dim=embed_dim)
|
||||||
|
|
||||||
|
self.with_cls_token = with_cls_token
|
||||||
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
||||||
self.pos_embed = nn.Parameter(
|
self.pos_embed = nn.Parameter(
|
||||||
torch.zeros(1, self.patch_embed.num_patches, embed_dim))
|
torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
|
||||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||||
|
|
||||||
self.blocks = nn.Sequential(*[
|
if isinstance(out_indices, int):
|
||||||
|
self.out_indices = [out_indices]
|
||||||
|
elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
|
||||||
|
self.out_indices = out_indices
|
||||||
|
else:
|
||||||
|
raise TypeError('out_indices must be type of int, list or tuple')
|
||||||
|
|
||||||
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
||||||
|
] # stochastic depth decay rule
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
Block(
|
Block(
|
||||||
dim=embed_dim,
|
dim=embed_dim,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
mlp_ratio=mlp_ratio,
|
mlp_ratio=mlp_ratio,
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
qk_scale=qk_scale,
|
qk_scale=qk_scale,
|
||||||
drop=drop_rate,
|
drop=dpr[i],
|
||||||
attn_drop=attn_drop_rate,
|
attn_drop=attn_drop_rate,
|
||||||
act_cfg=act_cfg,
|
act_cfg=act_cfg,
|
||||||
norm_cfg=norm_cfg,
|
norm_cfg=norm_cfg,
|
||||||
with_cp=with_cp) for i in range(depth)
|
with_cp=with_cp) for i in range(depth)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
self.interpolate_mode = interpolate_mode
|
||||||
|
self.final_norm = final_norm
|
||||||
|
if final_norm:
|
||||||
_, self.norm = build_norm_layer(norm_cfg, embed_dim)
|
_, self.norm = build_norm_layer(norm_cfg, embed_dim)
|
||||||
|
|
||||||
self.norm_eval = norm_eval
|
self.norm_eval = norm_eval
|
||||||
@ -283,28 +321,26 @@ class VisionTransformer(nn.Module):
|
|||||||
state_dict = checkpoint
|
state_dict = checkpoint
|
||||||
|
|
||||||
if 'pos_embed' in state_dict.keys():
|
if 'pos_embed' in state_dict.keys():
|
||||||
state_dict['pos_embed'] = state_dict['pos_embed'][:, 1:, :]
|
|
||||||
logger.info(
|
|
||||||
msg='Remove the "cls_token" dimension from the checkpoint')
|
|
||||||
|
|
||||||
if self.pos_embed.shape != state_dict['pos_embed'].shape:
|
if self.pos_embed.shape != state_dict['pos_embed'].shape:
|
||||||
logger.info(msg=f'Resize the pos_embed shape from \
|
logger.info(msg=f'Resize the pos_embed shape from \
|
||||||
{state_dict["pos_embed"].shape} to \
|
{state_dict["pos_embed"].shape} to {self.pos_embed.shape}')
|
||||||
{self.pos_embed.shape}')
|
|
||||||
h, w = self.img_size
|
h, w = self.img_size
|
||||||
pos_size = int(math.sqrt(state_dict['pos_embed'].shape[1]))
|
pos_size = int(
|
||||||
|
math.sqrt(state_dict['pos_embed'].shape[1] - 1))
|
||||||
state_dict['pos_embed'] = self.resize_pos_embed(
|
state_dict['pos_embed'] = self.resize_pos_embed(
|
||||||
state_dict['pos_embed'], (h, w), (pos_size, pos_size),
|
state_dict['pos_embed'], (h, w), (pos_size, pos_size),
|
||||||
self.patch_size)
|
self.patch_size, self.interpolate_mode)
|
||||||
|
|
||||||
self.load_state_dict(state_dict, False)
|
self.load_state_dict(state_dict, False)
|
||||||
|
|
||||||
elif pretrained is None:
|
elif pretrained is None:
|
||||||
# We only implement the 'jax_impl' initialization implemented at
|
# We only implement the 'jax_impl' initialization implemented at
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
||||||
normal_init(self.pos_embed)
|
trunc_normal_(self.pos_embed, std=.02)
|
||||||
|
trunc_normal_(self.cls_token, std=.02)
|
||||||
for n, m in self.named_modules():
|
for n, m in self.named_modules():
|
||||||
if isinstance(m, Linear):
|
if isinstance(m, Linear):
|
||||||
xavier_init(m.weight, distribution='uniform')
|
trunc_normal_(m.weight, std=.02)
|
||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
if 'mlp' in n:
|
if 'mlp' in n:
|
||||||
normal_init(m.bias, std=1e-6)
|
normal_init(m.bias, std=1e-6)
|
||||||
@ -316,7 +352,7 @@ class VisionTransformer(nn.Module):
|
|||||||
constant_init(m.bias, 0)
|
constant_init(m.bias, 0)
|
||||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
||||||
constant_init(m.bias, 0)
|
constant_init(m.bias, 0)
|
||||||
constant_init(m.weight, 1)
|
constant_init(m.weight, 1.0)
|
||||||
else:
|
else:
|
||||||
raise TypeError('pretrained must be a str or None')
|
raise TypeError('pretrained must be a str or None')
|
||||||
|
|
||||||
@ -340,7 +376,7 @@ class VisionTransformer(nn.Module):
|
|||||||
x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
|
x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
|
||||||
if x_len != pos_len:
|
if x_len != pos_len:
|
||||||
if pos_len == (self.img_size[0] // self.patch_size) * (
|
if pos_len == (self.img_size[0] // self.patch_size) * (
|
||||||
self.img_size[1] // self.patch_size):
|
self.img_size[1] // self.patch_size) + 1:
|
||||||
pos_h = self.img_size[0] // self.patch_size
|
pos_h = self.img_size[0] // self.patch_size
|
||||||
pos_w = self.img_size[1] // self.patch_size
|
pos_w = self.img_size[1] // self.patch_size
|
||||||
else:
|
else:
|
||||||
@ -348,11 +384,12 @@ class VisionTransformer(nn.Module):
|
|||||||
'Unexpected shape of pos_embed, got {}.'.format(
|
'Unexpected shape of pos_embed, got {}.'.format(
|
||||||
pos_embed.shape))
|
pos_embed.shape))
|
||||||
pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
|
pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
|
||||||
(pos_h, pos_w), self.patch_size)
|
(pos_h, pos_w), self.patch_size,
|
||||||
return patched_img + pos_embed
|
self.interpolate_mode)
|
||||||
|
return self.pos_drop(patched_img + pos_embed)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size):
|
def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
|
||||||
"""Resize pos_embed weights.
|
"""Resize pos_embed weights.
|
||||||
|
|
||||||
Resize pos_embed using bicubic interpolate method.
|
Resize pos_embed using bicubic interpolate method.
|
||||||
@ -367,26 +404,52 @@ class VisionTransformer(nn.Module):
|
|||||||
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
|
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
|
||||||
input_h, input_w = input_shpae
|
input_h, input_w = input_shpae
|
||||||
pos_h, pos_w = pos_shape
|
pos_h, pos_w = pos_shape
|
||||||
pos_embed = pos_embed.reshape(1, pos_h, pos_w,
|
cls_token_weight = pos_embed[:, 0]
|
||||||
pos_embed.shape[2]).permute(0, 3, 1, 2)
|
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
|
||||||
pos_embed = F.interpolate(
|
pos_embed_weight = pos_embed_weight.reshape(
|
||||||
pos_embed,
|
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
|
||||||
|
pos_embed_weight = F.interpolate(
|
||||||
|
pos_embed_weight,
|
||||||
size=[input_h // patch_size, input_w // patch_size],
|
size=[input_h // patch_size, input_w // patch_size],
|
||||||
align_corners=False,
|
align_corners=False,
|
||||||
mode='bicubic')
|
mode=mode)
|
||||||
pos_embed = torch.flatten(pos_embed, 2).transpose(1, 2)
|
cls_token_weight = cls_token_weight.unsqueeze(1)
|
||||||
|
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
|
||||||
|
pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
|
||||||
return pos_embed
|
return pos_embed
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
|
B = inputs.shape[0]
|
||||||
|
|
||||||
x = self.patch_embed(inputs)
|
x = self.patch_embed(inputs)
|
||||||
|
|
||||||
|
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||||
|
x = torch.cat((cls_tokens, x), dim=1)
|
||||||
x = self._pos_embeding(inputs, x, self.pos_embed)
|
x = self._pos_embeding(inputs, x, self.pos_embed)
|
||||||
x = self.blocks(x)
|
|
||||||
|
if not self.with_cls_token:
|
||||||
|
# Remove class token for transformer input
|
||||||
|
x = x[:, 1:]
|
||||||
|
|
||||||
|
outs = []
|
||||||
|
for i, blk in enumerate(self.blocks):
|
||||||
|
x = blk(x)
|
||||||
|
if i == len(self.blocks) - 1:
|
||||||
|
if self.final_norm:
|
||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
B, _, C = x.shape
|
if i in self.out_indices:
|
||||||
x = x.reshape(B, inputs.shape[2] // self.patch_size,
|
if self.with_cls_token:
|
||||||
|
# Remove class token and reshape token for decoder head
|
||||||
|
out = x[:, 1:]
|
||||||
|
else:
|
||||||
|
out = x
|
||||||
|
B, _, C = out.shape
|
||||||
|
out = out.reshape(B, inputs.shape[2] // self.patch_size,
|
||||||
inputs.shape[3] // self.patch_size,
|
inputs.shape[3] // self.patch_size,
|
||||||
C).permute(0, 3, 1, 2)
|
C).permute(0, 3, 1, 2)
|
||||||
return [x]
|
outs.append(out)
|
||||||
|
|
||||||
|
return tuple(outs)
|
||||||
|
|
||||||
def train(self, mode=True):
|
def train(self, mode=True):
|
||||||
super(VisionTransformer, self).train(mode)
|
super(VisionTransformer, self).train(mode)
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
|
from .drop import DropPath
|
||||||
from .inverted_residual import InvertedResidual, InvertedResidualV3
|
from .inverted_residual import InvertedResidual, InvertedResidualV3
|
||||||
from .make_divisible import make_divisible
|
from .make_divisible import make_divisible
|
||||||
from .res_layer import ResLayer
|
from .res_layer import ResLayer
|
||||||
from .se_layer import SELayer
|
from .se_layer import SELayer
|
||||||
from .self_attention_block import SelfAttentionBlock
|
from .self_attention_block import SelfAttentionBlock
|
||||||
from .up_conv_block import UpConvBlock
|
from .up_conv_block import UpConvBlock
|
||||||
|
from .weight_init import trunc_normal_
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
|
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
|
||||||
'UpConvBlock', 'InvertedResidualV3', 'SELayer'
|
'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'DropPath', 'trunc_normal_'
|
||||||
]
|
]
|
||||||
|
31
mmseg/models/utils/drop.py
Normal file
31
mmseg/models/utils/drop.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
"""Modified from https://github.com/rwightman/pytorch-image-
|
||||||
|
models/blob/master/timm/models/layers/drop.py."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class DropPath(nn.Module):
|
||||||
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of
|
||||||
|
residual blocks).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
drop_prob (float): Drop rate for paths of model. Dropout rate has
|
||||||
|
to be between 0 and 1. Default: 0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, drop_prob=0.):
|
||||||
|
super(DropPath, self).__init__()
|
||||||
|
self.drop_prob = drop_prob
|
||||||
|
self.keep_prob = 1 - drop_prob
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.drop_prob == 0. or not self.training:
|
||||||
|
return x
|
||||||
|
shape = (x.shape[0], ) + (1, ) * (
|
||||||
|
x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||||
|
random_tensor = self.keep_prob + torch.rand(
|
||||||
|
shape, dtype=x.dtype, device=x.device)
|
||||||
|
random_tensor.floor_() # binarize
|
||||||
|
output = x.div(self.keep_prob) * random_tensor
|
||||||
|
return output
|
62
mmseg/models/utils/weight_init.py
Normal file
62
mmseg/models/utils/weight_init.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
"""Modified from https://github.com/rwightman/pytorch-image-
|
||||||
|
models/blob/master/timm/models/layers/drop.py."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
||||||
|
"""Reference: https://people.sc.fsu.edu/~jburkardt/presentations
|
||||||
|
/truncated_normal.pdf"""
|
||||||
|
|
||||||
|
def norm_cdf(x):
|
||||||
|
# Computes standard normal cumulative distribution function
|
||||||
|
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
||||||
|
|
||||||
|
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
||||||
|
warnings.warn(
|
||||||
|
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
|
||||||
|
'The distribution of values may be incorrect.',
|
||||||
|
stacklevel=2)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# Values are generated by using a truncated uniform distribution and
|
||||||
|
# then using the inverse CDF for the normal distribution.
|
||||||
|
# Get upper and lower cdf values
|
||||||
|
lower_bound = norm_cdf((a - mean) / std)
|
||||||
|
upper_bound = norm_cdf((b - mean) / std)
|
||||||
|
|
||||||
|
# Uniformly fill tensor with values from [l, u], then translate to
|
||||||
|
# [2l-1, 2u-1].
|
||||||
|
tensor.uniform_(2 * lower_bound - 1, 2 * upper_bound - 1)
|
||||||
|
|
||||||
|
# Use inverse cdf transform for normal distribution to get truncated
|
||||||
|
# standard normal
|
||||||
|
tensor.erfinv_()
|
||||||
|
|
||||||
|
# Transform to proper mean, std
|
||||||
|
tensor.mul_(std * math.sqrt(2.))
|
||||||
|
tensor.add_(mean)
|
||||||
|
|
||||||
|
# Clamp to ensure it's in the proper range
|
||||||
|
tensor.clamp_(min=a, max=b)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
||||||
|
r"""Fills the input Tensor with values drawn from a truncated
|
||||||
|
normal distribution. The values are effectively drawn from the
|
||||||
|
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
||||||
|
with values outside :math:`[a, b]` redrawn until they are within
|
||||||
|
the bounds. The method used for generating the random values works
|
||||||
|
best when :math:`a \leq \text{mean} \leq b`.
|
||||||
|
Args:
|
||||||
|
tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`
|
||||||
|
mean (float): the mean of the normal distribution
|
||||||
|
std (float): the standard deviation of the normal distribution
|
||||||
|
a (float): the minimum cutoff value
|
||||||
|
b (float): the maximum cutoff value
|
||||||
|
"""
|
||||||
|
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
@ -15,10 +15,14 @@ def test_vit_backbone():
|
|||||||
# img_size must be int or tuple
|
# img_size must be int or tuple
|
||||||
model = VisionTransformer(img_size=512.0)
|
model = VisionTransformer(img_size=512.0)
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
# out_indices must be int ,list or tuple
|
||||||
|
model = VisionTransformer(out_indices=1.)
|
||||||
|
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
# test upsample_pos_embed function
|
# test upsample_pos_embed function
|
||||||
x = torch.randn(1, 196)
|
x = torch.randn(1, 196)
|
||||||
VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224)
|
VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear')
|
||||||
|
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
# forward inputs must be [N, C, H, W]
|
# forward inputs must be [N, C, H, W]
|
||||||
@ -46,19 +50,25 @@ def test_vit_backbone():
|
|||||||
# Test large size input image
|
# Test large size input image
|
||||||
imgs = torch.randn(1, 3, 256, 256)
|
imgs = torch.randn(1, 3, 256, 256)
|
||||||
feat = model(imgs)
|
feat = model(imgs)
|
||||||
assert feat[0].shape == (1, 768, 16, 16)
|
assert feat[-1].shape == (1, 768, 16, 16)
|
||||||
|
|
||||||
# Test small size input image
|
# Test small size input image
|
||||||
imgs = torch.randn(1, 3, 32, 32)
|
imgs = torch.randn(1, 3, 32, 32)
|
||||||
feat = model(imgs)
|
feat = model(imgs)
|
||||||
assert feat[0].shape == (1, 768, 2, 2)
|
assert feat[-1].shape == (1, 768, 2, 2)
|
||||||
|
|
||||||
imgs = torch.randn(1, 3, 224, 224)
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
feat = model(imgs)
|
feat = model(imgs)
|
||||||
assert feat[0].shape == (1, 768, 14, 14)
|
assert feat[-1].shape == (1, 768, 14, 14)
|
||||||
|
|
||||||
# Test with_cp=True
|
# Test with_cp=True
|
||||||
model = VisionTransformer(with_cp=True)
|
model = VisionTransformer(with_cp=True)
|
||||||
imgs = torch.randn(1, 3, 224, 224)
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
feat = model(imgs)
|
feat = model(imgs)
|
||||||
assert feat[0].shape == (1, 768, 14, 14)
|
assert feat[-1].shape == (1, 768, 14, 14)
|
||||||
|
|
||||||
|
# Test with_cls_token=False
|
||||||
|
model = VisionTransformer(with_cls_token=False)
|
||||||
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
|
feat = model(imgs)
|
||||||
|
assert feat[-1].shape == (1, 768, 14, 14)
|
||||||
|
28
tests/test_models/test_utils/test_drop.py
Normal file
28
tests/test_models/test_utils/test_drop.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from mmseg.models.utils import DropPath
|
||||||
|
|
||||||
|
|
||||||
|
def test_drop_path():
|
||||||
|
|
||||||
|
# zero drop
|
||||||
|
layer = DropPath()
|
||||||
|
|
||||||
|
# input NLC format feature
|
||||||
|
x = torch.randn((1, 16, 32))
|
||||||
|
layer(x)
|
||||||
|
|
||||||
|
# input NLHW format feature
|
||||||
|
x = torch.randn((1, 32, 4, 4))
|
||||||
|
layer(x)
|
||||||
|
|
||||||
|
# non-zero drop
|
||||||
|
layer = DropPath(0.1)
|
||||||
|
|
||||||
|
# input NLC format feature
|
||||||
|
x = torch.randn((1, 16, 32))
|
||||||
|
layer(x)
|
||||||
|
|
||||||
|
# input NLHW format feature
|
||||||
|
x = torch.randn((1, 32, 4, 4))
|
||||||
|
layer(x)
|
Loading…
x
Reference in New Issue
Block a user