mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Fix some vit init bugs (#609)
* [Fix] Fix vit init bug * Add some vit unit tests * Modify module import * Fix pretrain weights bug * Modify pretrained judge * Add some unit tests to improve code cov * Optimize code * Fix vit unit test
This commit is contained in:
parent
458fc7897f
commit
0c4c3b790d
@ -1,4 +1,5 @@
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -6,8 +7,7 @@ import torch.nn.functional as F
|
||||
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
|
||||
kaiming_init, normal_init, trunc_normal_init)
|
||||
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
|
||||
from mmcv.runner import _load_checkpoint
|
||||
from mmcv.runner.base_module import BaseModule, ModuleList
|
||||
from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from torch.nn.modules.utils import _pair as to_2tuple
|
||||
|
||||
@ -140,12 +140,6 @@ class PatchEmbed(BaseModule):
|
||||
self.norm = None
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
# FIXME look at relaxing size constraints
|
||||
# assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
# f"Input image size ({H}*{W}) doesn't " \
|
||||
# f'match model ({self.img_size[0]}*{self.img_size[1]}).'
|
||||
# The output size is (B, N, D), where N=H*W/P/P, D is embid_dim
|
||||
x = self.projection(x).flatten(2).transpose(1, 2)
|
||||
|
||||
if self.norm is not None:
|
||||
@ -185,8 +179,12 @@ class VisionTransformer(BaseModule):
|
||||
Default: dict(type='LN')
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Defalut: dict(type='GELU').
|
||||
final_norm (bool): Whether to add a additional layer to normalize
|
||||
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
|
||||
Default: False.
|
||||
final_norm (bool): Whether to add a additional layer to normalize
|
||||
final feature map. Default: False.
|
||||
out_shape (str): Select the output format of feature information.
|
||||
Default: NCHW.
|
||||
interpolate_mode (str): Select the interpolate mode for position
|
||||
embeding vector resize. Default: bicubic.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
@ -198,6 +196,9 @@ class VisionTransformer(BaseModule):
|
||||
some memory while slowing down the training speed. Default: False.
|
||||
pretrain_style (str): Choose to use timm or mmcls pretrain weights.
|
||||
Default: timm.
|
||||
pretrained (str, optional): model pretrained path. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -216,12 +217,16 @@ class VisionTransformer(BaseModule):
|
||||
with_cls_token=True,
|
||||
norm_cfg=dict(type='LN'),
|
||||
act_cfg=dict(type='GELU'),
|
||||
patch_norm=False,
|
||||
final_norm=False,
|
||||
out_shape='NCHW',
|
||||
interpolate_mode='bicubic',
|
||||
num_fcs=2,
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
pretrain_style='timm'):
|
||||
pretrain_style='timm',
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super(VisionTransformer, self).__init__()
|
||||
|
||||
if isinstance(img_size, int):
|
||||
@ -235,16 +240,32 @@ class VisionTransformer(BaseModule):
|
||||
|
||||
assert pretrain_style in ['timm', 'mmcls']
|
||||
|
||||
self.pretrain_style = pretrain_style
|
||||
assert out_shape in ['NLC',
|
||||
'NCHW'], 'output shape must be "NLC" or "NCHW".'
|
||||
|
||||
if isinstance(pretrained, str) or pretrained is None:
|
||||
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.out_shape = out_shape
|
||||
self.interpolate_mode = interpolate_mode
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
self.pretrain_style = pretrain_style
|
||||
self.pretrained = pretrained
|
||||
self.init_cfg = init_cfg
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=embed_dims,
|
||||
norm_cfg=norm_cfg)
|
||||
norm_cfg=norm_cfg if patch_norm else None)
|
||||
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.with_cls_token = with_cls_token
|
||||
@ -280,24 +301,20 @@ class VisionTransformer(BaseModule):
|
||||
norm_cfg=norm_cfg,
|
||||
batch_first=True))
|
||||
|
||||
self.interpolate_mode = interpolate_mode
|
||||
self.final_norm = final_norm
|
||||
if final_norm:
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
norm_cfg, embed_dims, postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
if isinstance(pretrained, str):
|
||||
def init_weights(self):
|
||||
if isinstance(self.pretrained, str):
|
||||
logger = get_root_logger()
|
||||
checkpoint = _load_checkpoint(pretrained, logger=logger)
|
||||
checkpoint = _load_checkpoint(self.pretrained, logger=logger)
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
elif 'model' in checkpoint:
|
||||
@ -325,7 +342,8 @@ class VisionTransformer(BaseModule):
|
||||
|
||||
self.load_state_dict(state_dict, False)
|
||||
|
||||
elif pretrained is None:
|
||||
elif self.pretrained is None:
|
||||
super(VisionTransformer, self).init_weights()
|
||||
# We only implement the 'jax_impl' initialization implemented at
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
||||
trunc_normal_init(self.pos_embed, std=.02)
|
||||
@ -345,8 +363,6 @@ class VisionTransformer(BaseModule):
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
||||
constant_init(m.bias, 0)
|
||||
constant_init(m.weight, 1.0)
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
def _pos_embeding(self, img, patched_img, pos_embed):
|
||||
"""Positiong embeding method.
|
||||
@ -436,10 +452,11 @@ class VisionTransformer(BaseModule):
|
||||
out = x[:, 1:]
|
||||
else:
|
||||
out = x
|
||||
B, _, C = out.shape
|
||||
out = out.reshape(B, inputs.shape[2] // self.patch_size,
|
||||
inputs.shape[3] // self.patch_size,
|
||||
C).permute(0, 3, 1, 2)
|
||||
if self.out_shape == 'NCHW':
|
||||
B, _, C = out.shape
|
||||
out = out.reshape(B, inputs.shape[2] // self.patch_size,
|
||||
inputs.shape[3] // self.patch_size,
|
||||
C).permute(0, 3, 1, 2)
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
||||
|
@ -27,7 +27,6 @@ def vit_convert(timm_dict):
|
||||
new_k = new_k.replace('attn.proj', 'attn.attn.out_proj')
|
||||
else:
|
||||
new_k = k
|
||||
new_k = f'backbone.{new_k}'
|
||||
mmseg_dict[new_k] = v
|
||||
|
||||
return mmseg_dict
|
||||
|
@ -24,19 +24,33 @@ def test_vit_backbone():
|
||||
x = torch.randn(1, 196)
|
||||
VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(RuntimeError):
|
||||
# forward inputs must be [N, C, H, W]
|
||||
x = torch.randn(3, 30, 30)
|
||||
model = VisionTransformer()
|
||||
model(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# The length of img_size tuple must be lower than 3.
|
||||
VisionTransformer(img_size=(224, 224, 224))
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# Pretrained must be None or Str.
|
||||
VisionTransformer(pretrained=123)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# out_shape must be 'NLC' or 'NCHW;'
|
||||
VisionTransformer(out_shape='NCL')
|
||||
|
||||
# Test img_size isinstance tuple
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
model = VisionTransformer(img_size=(224, ))
|
||||
model.init_weights()
|
||||
model(imgs)
|
||||
|
||||
# Test img_size isinstance tuple
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
model = VisionTransformer(img_size=(224, 224))
|
||||
model.init_weights()
|
||||
model(imgs)
|
||||
|
||||
# Test norm_eval = True
|
||||
@ -50,6 +64,11 @@ def test_vit_backbone():
|
||||
|
||||
assert check_norm_state(model.modules(), True)
|
||||
|
||||
# Test normal size input image
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 14)
|
||||
|
||||
# Test large size input image
|
||||
imgs = torch.randn(1, 3, 256, 256)
|
||||
feat = model(imgs)
|
||||
@ -81,8 +100,20 @@ def test_vit_backbone():
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 14)
|
||||
|
||||
# Test out_shape == 'NLC'
|
||||
model = VisionTransformer(out_shape='NLC')
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 196, 768)
|
||||
|
||||
# Test final norm
|
||||
model = VisionTransformer(final_norm=True)
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 14)
|
||||
|
||||
# Test patch norm
|
||||
model = VisionTransformer(patch_norm=True)
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 14)
|
||||
|
Loading…
x
Reference in New Issue
Block a user