Mirror of https://github.com/open-mmlab/mmsegmentation.git
Synced 2025-06-03 22:03:48 +08:00
[Fix] Fix some vit init bugs (#609)
* [Fix] Fix vit init bug
* Add some vit unit tests
* Modify module import
* Fix pretrain weights bug
* Modify pretrained judge
* Add some unit tests to improve code cov
* Optimize code
* Fix vit unit test
This commit is contained in:
parent 458fc7897f
commit 0c4c3b790d
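In summary, the commit moves the `pretrained` argument from `init_weights()` into the constructor (keeping it only as a deprecated alias for `init_cfg`), adds `patch_norm` and `out_shape` options, and validates inputs at construction time. A minimal sketch of the resulting call pattern, assuming this version of mmsegmentation is installed (the keyword values are illustrative, not taken from the diff):

```python
from mmseg.models.backbones.vit import VisionTransformer

# New-style construction after this commit: `pretrained` is a constructor
# argument (deprecated in favour of `init_cfg`), and `init_weights()` no
# longer takes any arguments.
model = VisionTransformer(
    img_size=(224, 224),
    patch_size=16,
    patch_norm=True,     # new: apply norm_cfg inside PatchEmbed
    out_shape='NCHW',    # new: 'NCHW' feature maps or 'NLC' token sequences
    pretrained=None)     # a str path would be loaded in init_weights()
model.init_weights()
```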
@@ -1,4 +1,5 @@
 import math
+import warnings
 
 import torch
 import torch.nn as nn
@@ -6,8 +7,7 @@ import torch.nn.functional as F
 from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
                       kaiming_init, normal_init, trunc_normal_init)
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
-from mmcv.runner import _load_checkpoint
-from mmcv.runner.base_module import BaseModule, ModuleList
+from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.utils import _pair as to_2tuple
 
@@ -140,12 +140,6 @@ class PatchEmbed(BaseModule):
             self.norm = None
 
     def forward(self, x):
-        B, C, H, W = x.shape
-        # FIXME look at relaxing size constraints
-        # assert H == self.img_size[0] and W == self.img_size[1], \
-        #     f"Input image size ({H}*{W}) doesn't " \
-        #     f'match model ({self.img_size[0]}*{self.img_size[1]}).'
-        # The output size is (B, N, D), where N=H*W/P/P, D is embid_dim
         x = self.projection(x).flatten(2).transpose(1, 2)
 
         if self.norm is not None:
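The forward pass of `PatchEmbed` is now a bare projection, with the commented-out size assertions dropped. A standalone sketch of the shape arithmetic it performs, using plain PyTorch rather than the mmseg class itself:

```python
import torch
import torch.nn as nn

# Stand-in for PatchEmbed.forward: a strided conv projection followed by
# flatten(2).transpose(1, 2), turning an image into a token sequence.
B, C, H, W, P, D = 1, 3, 224, 224, 16, 768
projection = nn.Conv2d(C, D, kernel_size=P, stride=P)

x = torch.randn(B, C, H, W)
x = projection(x)                  # (B, D, H/P, W/P) -> (1, 768, 14, 14)
x = x.flatten(2).transpose(1, 2)   # (B, N, D) with N = H*W / P**2 = 196
assert x.shape == (1, 196, 768)
```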
@@ -185,8 +179,12 @@ class VisionTransformer(BaseModule):
             Default: dict(type='LN')
         act_cfg (dict): The activation config for FFNs.
             Defalut: dict(type='GELU').
-        final_norm (bool): Whether to add a additional layer to normalize
+        patch_norm (bool): Whether to add a norm in PatchEmbed Block.
+            Default: False.
+        final_norm (bool): Whether to add a additional layer to normalize
             final feature map. Default: False.
+        out_shape (str): Select the output format of feature information.
+            Default: NCHW.
         interpolate_mode (str): Select the interpolate mode for position
             embeding vector resize. Default: bicubic.
         num_fcs (int): The number of fully-connected layers for FFNs.
@@ -198,6 +196,9 @@ class VisionTransformer(BaseModule):
             some memory while slowing down the training speed. Default: False.
         pretrain_style (str): Choose to use timm or mmcls pretrain weights.
             Default: timm.
+        pretrained (str, optional): model pretrained path. Default: None.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
     """
 
     def __init__(self,
@@ -216,12 +217,16 @@ class VisionTransformer(BaseModule):
                  with_cls_token=True,
                  norm_cfg=dict(type='LN'),
                  act_cfg=dict(type='GELU'),
+                 patch_norm=False,
                  final_norm=False,
+                 out_shape='NCHW',
                  interpolate_mode='bicubic',
                  num_fcs=2,
                  norm_eval=False,
                  with_cp=False,
-                 pretrain_style='timm'):
+                 pretrain_style='timm',
+                 pretrained=None,
+                 init_cfg=None):
         super(VisionTransformer, self).__init__()
 
         if isinstance(img_size, int):
@@ -235,16 +240,32 @@ class VisionTransformer(BaseModule):
 
         assert pretrain_style in ['timm', 'mmcls']
 
-        self.pretrain_style = pretrain_style
+        assert out_shape in ['NLC',
+                             'NCHW'], 'output shape must be "NLC" or "NCHW".'
+
+        if isinstance(pretrained, str) or pretrained is None:
+            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
+                          'please use "init_cfg" instead')
+        else:
+            raise TypeError('pretrained must be a str or None')
+
         self.img_size = img_size
         self.patch_size = patch_size
+        self.out_shape = out_shape
+        self.interpolate_mode = interpolate_mode
+        self.norm_eval = norm_eval
+        self.with_cp = with_cp
+        self.pretrain_style = pretrain_style
+        self.pretrained = pretrained
+        self.init_cfg = init_cfg
 
         self.patch_embed = PatchEmbed(
             img_size=img_size,
             patch_size=patch_size,
             in_channels=in_channels,
             embed_dim=embed_dims,
-            norm_cfg=norm_cfg)
+            norm_cfg=norm_cfg if patch_norm else None)
 
         num_patches = self.patch_embed.num_patches
 
         self.with_cls_token = with_cls_token
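The constructor now owns the `pretrained`-versus-`init_cfg` decision, so an invalid type fails at build time instead of inside `init_weights()`. A minimal sketch of that pattern in isolation (the class and path names here are illustrative, not from the diff); note that as written in the hunk above, the warning also fires for the valid `pretrained=None` case:

```python
import warnings

# Deprecation pattern used above: accept the legacy `pretrained` argument,
# warn that `init_cfg` is preferred, and fail fast on anything that is
# neither a str nor None.
class Backbone:
    def __init__(self, pretrained=None, init_cfg=None):
        if isinstance(pretrained, str) or pretrained is None:
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
        else:
            raise TypeError('pretrained must be a str or None')
        self.pretrained = pretrained
        self.init_cfg = init_cfg

Backbone(pretrained='checkpoint.pth')   # warns, but is accepted
try:
    Backbone(pretrained=123)            # rejected at construction time
except TypeError as e:
    print(e)                            # pretrained must be a str or None
```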
@@ -280,24 +301,20 @@ class VisionTransformer(BaseModule):
                     norm_cfg=norm_cfg,
                     batch_first=True))
 
-        self.interpolate_mode = interpolate_mode
         self.final_norm = final_norm
         if final_norm:
             self.norm1_name, norm1 = build_norm_layer(
                 norm_cfg, embed_dims, postfix=1)
             self.add_module(self.norm1_name, norm1)
 
-        self.norm_eval = norm_eval
-        self.with_cp = with_cp
-
     @property
     def norm1(self):
         return getattr(self, self.norm1_name)
 
-    def init_weights(self, pretrained=None):
-        if isinstance(pretrained, str):
+    def init_weights(self):
+        if isinstance(self.pretrained, str):
             logger = get_root_logger()
-            checkpoint = _load_checkpoint(pretrained, logger=logger)
+            checkpoint = _load_checkpoint(self.pretrained, logger=logger)
             if 'state_dict' in checkpoint:
                 state_dict = checkpoint['state_dict']
             elif 'model' in checkpoint:
@@ -325,7 +342,8 @@ class VisionTransformer(BaseModule):
 
             self.load_state_dict(state_dict, False)
 
-        elif pretrained is None:
+        elif self.pretrained is None:
+            super(VisionTransformer, self).init_weights()
             # We only implement the 'jax_impl' initialization implemented at
             # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
             trunc_normal_init(self.pos_embed, std=.02)
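`init_weights()` now reads `self.pretrained`, and the `pretrained=None` branch first defers to `BaseModule.init_weights()` so that `init_cfg` is honoured before the timm-style `jax_impl` initialization. In the loading branch, the checkpoint dict may wrap the weights under different keys; a minimal sketch of that unwrapping, with an in-memory stand-in for what mmcv's `_load_checkpoint` returns from a path or URL:

```python
# Stand-in payload; _load_checkpoint would produce such a dict from disk.
checkpoint = {'state_dict': {'pos_embed': '...'}}

if 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']   # mmcls-style wrapper
elif 'model' in checkpoint:
    state_dict = checkpoint['model']        # timm-style wrapper
else:
    state_dict = checkpoint                 # assume a bare state dict

print(sorted(state_dict))                   # ['pos_embed']
```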
@@ -345,8 +363,6 @@ class VisionTransformer(BaseModule):
         elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
             constant_init(m.bias, 0)
             constant_init(m.weight, 1.0)
-        else:
-            raise TypeError('pretrained must be a str or None')
 
     def _pos_embeding(self, img, patched_img, pos_embed):
         """Positiong embeding method.
@@ -436,10 +452,11 @@ class VisionTransformer(BaseModule):
                     out = x[:, 1:]
                 else:
                     out = x
-                B, _, C = out.shape
-                out = out.reshape(B, inputs.shape[2] // self.patch_size,
-                                  inputs.shape[3] // self.patch_size,
-                                  C).permute(0, 3, 1, 2)
+                if self.out_shape == 'NCHW':
+                    B, _, C = out.shape
+                    out = out.reshape(B, inputs.shape[2] // self.patch_size,
+                                      inputs.shape[3] // self.patch_size,
+                                      C).permute(0, 3, 1, 2)
                 outs.append(out)
 
         return tuple(outs)
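The reshape into `(B, C, H/P, W/P)` is now gated on `out_shape == 'NCHW'`; with `'NLC'` the `(B, N, C)` token layout is returned untouched. A quick sketch of the layout round-trip, with numbers assuming a 224x224 input and 16x16 patches:

```python
import torch

# Shape round-trip performed by the new out_shape branch: tokens in
# (B, N, C) layout become a (B, C, H/P, W/P) feature map for 'NCHW'.
B, C, H, W, P = 1, 768, 224, 224, 16
tokens = torch.randn(B, (H // P) * (W // P), C)            # (1, 196, 768)

nchw = tokens.reshape(B, H // P, W // P, C).permute(0, 3, 1, 2)
assert nchw.shape == (1, 768, 14, 14)
# Inverting the permute/reshape recovers the original token sequence.
assert torch.equal(nchw.permute(0, 2, 3, 1).reshape(B, -1, C), tokens)
```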
@@ -27,7 +27,6 @@ def vit_convert(timm_dict):
             new_k = new_k.replace('attn.proj', 'attn.attn.out_proj')
         else:
             new_k = k
-        new_k = f'backbone.{new_k}'
         mmseg_dict[new_k] = v
 
     return mmseg_dict
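This hunk is from the timm-to-mmseg weight converter rather than vit.py itself. Dropping the `backbone.` prefix makes the converted keys line up with `VisionTransformer`'s own `state_dict()`, which is what `init_weights()` now loads into (presumably the "pretrain weights bug" named in the commit message). A toy illustration of why the prefix mattered; the keys are representative, not exhaustive:

```python
# Prefixed keys match a whole-segmentor state_dict, not the backbone's own
# keys, so loading directly into the backbone matched nothing.
converted = {'patch_embed.projection.weight': '...'}            # after fix
old_style = {f'backbone.{k}': v for k, v in converted.items()}  # before fix

backbone_keys = {'patch_embed.projection.weight'}
print(backbone_keys & set(converted))   # {'patch_embed.projection.weight'}
print(backbone_keys & set(old_style))   # set()
```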
@@ -24,19 +24,33 @@ def test_vit_backbone():
         x = torch.randn(1, 196)
         VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear')
 
-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         # forward inputs must be [N, C, H, W]
         x = torch.randn(3, 30, 30)
         model = VisionTransformer()
         model(x)
 
     with pytest.raises(AssertionError):
+        # The length of img_size tuple must be lower than 3.
         VisionTransformer(img_size=(224, 224, 224))
 
+    with pytest.raises(TypeError):
+        # Pretrained must be None or Str.
+        VisionTransformer(pretrained=123)
+
+    with pytest.raises(AssertionError):
+        # out_shape must be 'NLC' or 'NCHW;'
+        VisionTransformer(out_shape='NCL')
+
+    # Test img_size isinstance tuple
+    imgs = torch.randn(1, 3, 224, 224)
+    model = VisionTransformer(img_size=(224, ))
+    model.init_weights()
+    model(imgs)
+
     # Test img_size isinstance tuple
     imgs = torch.randn(1, 3, 224, 224)
     model = VisionTransformer(img_size=(224, 224))
-    model.init_weights()
     model(imgs)
 
     # Test norm_eval = True
@ -50,6 +64,11 @@ def test_vit_backbone():
|
|||||||
|
|
||||||
assert check_norm_state(model.modules(), True)
|
assert check_norm_state(model.modules(), True)
|
||||||
|
|
||||||
|
# Test normal size input image
|
||||||
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
|
feat = model(imgs)
|
||||||
|
assert feat[-1].shape == (1, 768, 14, 14)
|
||||||
|
|
||||||
# Test large size input image
|
# Test large size input image
|
||||||
imgs = torch.randn(1, 3, 256, 256)
|
imgs = torch.randn(1, 3, 256, 256)
|
||||||
feat = model(imgs)
|
feat = model(imgs)
|
||||||
@ -81,8 +100,20 @@ def test_vit_backbone():
|
|||||||
feat = model(imgs)
|
feat = model(imgs)
|
||||||
assert feat[-1].shape == (1, 768, 14, 14)
|
assert feat[-1].shape == (1, 768, 14, 14)
|
||||||
|
|
||||||
|
# Test out_shape == 'NLC'
|
||||||
|
model = VisionTransformer(out_shape='NLC')
|
||||||
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
|
feat = model(imgs)
|
||||||
|
assert feat[-1].shape == (1, 196, 768)
|
||||||
|
|
||||||
# Test final norm
|
# Test final norm
|
||||||
model = VisionTransformer(final_norm=True)
|
model = VisionTransformer(final_norm=True)
|
||||||
imgs = torch.randn(1, 3, 224, 224)
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
feat = model(imgs)
|
feat = model(imgs)
|
||||||
assert feat[-1].shape == (1, 768, 14, 14)
|
assert feat[-1].shape == (1, 768, 14, 14)
|
||||||
|
|
||||||
|
# Test patch norm
|
||||||
|
model = VisionTransformer(patch_norm=True)
|
||||||
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
|
feat = model(imgs)
|
||||||
|
assert feat[-1].shape == (1, 768, 14, 14)
|
||||||
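The updated test can be run on its own to exercise the new branches; a sketch assuming the usual mmsegmentation test layout (the exact path is an assumption):

```python
import pytest

# Run just the ViT backbone tests from the repo root.
pytest.main(['-q', 'tests/test_models/test_backbones/test_vit.py'])
```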