[Fix] fix patch_embed and pos_embed mismatch error (#685)
* fix patch_embed and pos_embed mismatch error
* add docstring
* update unittest
* use downsampled image shape
* use tuple
* remove unused parameters and add doc
* fix init weights function
* revise docstring
* Update vit.py If -> Whether
* fix lint

Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
parent 5097d55f8e
commit dff7a968a3
@@ -21,7 +21,6 @@ model = dict(
         norm_cfg=dict(type='LN', eps=1e-6),
         act_cfg=dict(type='GELU'),
         norm_eval=False,
-        out_shape='NCHW',
         interpolate_mode='bicubic'),
     neck=dict(
         type='MultiLevelNeck',
@@ -118,8 +118,10 @@ class VisionTransformer(BaseModule):
         attn_drop_rate (float): The drop out rate for attention layer.
             Default 0.0
         drop_path_rate (float): stochastic depth rate. Default 0.0
-        with_cls_token (bool): If concatenating class token into image tokens
-            as transformer input. Default: True.
+        with_cls_token (bool): Whether concatenating class token into image
+            tokens as transformer input. Default: True.
+        output_cls_token (bool): Whether output the cls_token. If set True,
+            `with_cls_token` must be True. Default: False.
         norm_cfg (dict): Config dict for normalization layer.
             Default: dict(type='LN')
         act_cfg (dict): The activation config for FFNs.
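The new `output_cls_token` option changes a stage's output from a single feature map to a `[feature_map, cls_token]` pair. A quick shape sketch, mirroring the unit test added at the end of this diff (the import path matches the test file; treat the example as illustrative):

```python
import torch
from mmseg.models.backbones.vit import VisionTransformer

# With output_cls_token=True each selected stage yields a [map, cls] pair;
# shapes below are those asserted in the new unit test (patch size 16).
model = VisionTransformer(with_cls_token=True, output_cls_token=True)
feat = model(torch.randn(1, 3, 224, 224))
patch_tokens, cls_token = feat[0]
assert patch_tokens.shape == (1, 768, 14, 14)  # 224 / 16 = 14 patches per side
assert cls_token.shape == (1, 768)
```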
@@ -128,8 +130,6 @@ class VisionTransformer(BaseModule):
             Default: False.
         final_norm (bool): Whether to add a additional layer to normalize
             final feature map. Default: False.
-        out_shape (str): Select the output format of feature information.
-            Default: NCHW.
         interpolate_mode (str): Select the interpolate mode for position
             embeding vector resize. Default: bicubic.
         num_fcs (int): The number of fully-connected layers for FFNs.
@@ -160,11 +160,11 @@ class VisionTransformer(BaseModule):
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 with_cls_token=True,
+                output_cls_token=False,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 patch_norm=False,
                 final_norm=False,
-                out_shape='NCHW',
                 interpolate_mode='bicubic',
                 num_fcs=2,
                 norm_eval=False,
@@ -185,8 +185,9 @@ class VisionTransformer(BaseModule):

         assert pretrain_style in ['timm', 'mmcls']

-        assert out_shape in ['NLC',
-                             'NCHW'], 'output shape must be "NLC" or "NCHW".'
+        if output_cls_token:
+            assert with_cls_token is True, 'with_cls_token must be True if ' \
+                f'set output_cls_token to True, but got {with_cls_token}'

         if isinstance(pretrained, str) or pretrained is None:
             warnings.warn('DeprecationWarning: pretrained is a deprecated, '
@@ -196,7 +197,6 @@ class VisionTransformer(BaseModule):

         self.img_size = img_size
         self.patch_size = patch_size
-        self.out_shape = out_shape
         self.interpolate_mode = interpolate_mode
         self.norm_eval = norm_eval
         self.with_cp = with_cp
@@ -218,6 +218,7 @@ class VisionTransformer(BaseModule):
            (img_size[1] // patch_size)

         self.with_cls_token = with_cls_token
+        self.output_cls_token = output_cls_token
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
         self.pos_embed = nn.Parameter(
             torch.zeros(1, num_patches + 1, embed_dims))
@@ -253,7 +254,6 @@ class VisionTransformer(BaseModule):
                     batch_first=True))

         self.final_norm = final_norm
-        self.out_shape = out_shape
         if final_norm:
             self.norm1_name, norm1 = build_norm_layer(
                 norm_cfg, embed_dims, postfix=1)
@@ -290,8 +290,9 @@ class VisionTransformer(BaseModule):
                     pos_size = int(
                         math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                     state_dict['pos_embed'] = self.resize_pos_embed(
-                        state_dict['pos_embed'], (h, w), (pos_size, pos_size),
-                        self.patch_size, self.interpolate_mode)
+                        state_dict['pos_embed'],
+                        (h // self.patch_size, w // self.patch_size),
+                        (pos_size, pos_size), self.interpolate_mode)

             self.load_state_dict(state_dict, False)

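With this change the checkpoint's position embedding is resized from the pretraining grid straight to the current patch grid, so the patch size no longer has to be threaded through `resize_pos_embed`. A worked shape sketch (illustrative values assumed: fine-tuning at 512x512 on a checkpoint pretrained at 224x224 with patch size 16):

```python
# Worked arithmetic for the call above (assumed illustrative values).
h, w = 512, 512
patch_size = 16
pos_size = 14  # int(sqrt(197 - 1)) for a 224x224 checkpoint with a cls token
target_grid = (h // patch_size, w // patch_size)
assert target_grid == (32, 32)
# resize_pos_embed interpolates the (14, 14) grid to (32, 32),
# giving 1 + 32 * 32 = 1025 position tokens after the cls token is re-attached.
```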
@@ -317,16 +318,15 @@ class VisionTransformer(BaseModule):
                         constant_init(m.bias, 0)
                     constant_init(m.weight, 1.0)

-    def _pos_embeding(self, img, patched_img, pos_embed):
+    def _pos_embeding(self, patched_img, hw_shape, pos_embed):
         """Positiong embeding method.

         Resize the pos_embed, if the input image size doesn't match
             the training size.
         Args:
-            img (torch.Tensor): The inference image tensor, the shape
-                must be [B, C, H, W].
             patched_img (torch.Tensor): The patched image, it should be
                 shape of [B, L1, C].
+            hw_shape (tuple): The downsampled image resolution.
             pos_embed (torch.Tensor): The pos_embed weighs, it should be
                 shape of [B, L2, c].
         Return:
@@ -344,36 +344,36 @@ class VisionTransformer(BaseModule):
                 raise ValueError(
                     'Unexpected shape of pos_embed, got {}.'.format(
                         pos_embed.shape))
-            pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
-                                              (pos_h, pos_w), self.patch_size,
+            pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
+                                              (pos_h, pos_w),
                                               self.interpolate_mode)
         return self.drop_after_pos(patched_img + pos_embed)

     @staticmethod
-    def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
+    def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
         """Resize pos_embed weights.

         Resize pos_embed using bicubic interpolate method.
         Args:
-            pos_embed (torch.Tensor): pos_embed weights.
-            input_shpae (tuple): Tuple for (input_h, intput_w).
-            pos_shape (tuple): Tuple for (pos_h, pos_w).
-            patch_size (int): Patch size.
+            pos_embed (torch.Tensor): Position embedding weights.
+            input_shpae (tuple): Tuple for (downsampled input image height,
+                downsampled input image width).
+            pos_shape (tuple): The resolution of downsampled origin training
+                image.
+            mode (str): Algorithm used for upsampling:
+                ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
+                ``'trilinear'``. Default: ``'nearest'``
         Return:
             torch.Tensor: The resized pos_embed of shape [B, L_new, C]
         """
         assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
-        input_h, input_w = input_shpae
         pos_h, pos_w = pos_shape
         cls_token_weight = pos_embed[:, 0]
         pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
         pos_embed_weight = pos_embed_weight.reshape(
             1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
         pos_embed_weight = F.interpolate(
-            pos_embed_weight,
-            size=[input_h // patch_size, input_w // patch_size],
-            align_corners=False,
-            mode=mode)
+            pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
         cls_token_weight = cls_token_weight.unsqueeze(1)
         pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
         pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
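For reference, the interpolation step above can be exercised standalone. A minimal sketch with the same logic as the new `resize_pos_embed` (illustrative shapes: a 14x14 pretraining grid resized to a 32x32 target grid):

```python
import torch
import torch.nn.functional as F

# Standalone sketch of the resize above: split off the cls token,
# interpolate the patch grid, and concatenate back.
pos_embed = torch.zeros(1, 1 + 14 * 14, 768)  # [B, 1 + pos_h * pos_w, C]
pos_h, pos_w = 14, 14
input_shape = (32, 32)                         # downsampled target grid

cls_tok = pos_embed[:, 0].unsqueeze(1)
grid = pos_embed[:, -pos_h * pos_w:].reshape(1, pos_h, pos_w, 768)
grid = grid.permute(0, 3, 1, 2)                # NHWC -> NCHW for interpolate
grid = F.interpolate(
    grid, size=input_shape, align_corners=False, mode='bicubic')
grid = torch.flatten(grid, 2).transpose(1, 2)  # back to [B, L_new, C]
resized = torch.cat((cls_tok, grid), dim=1)
assert resized.shape == (1, 1 + 32 * 32, 768)
```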
@@ -382,12 +382,12 @@ class VisionTransformer(BaseModule):
     def forward(self, inputs):
         B = inputs.shape[0]

-        x = self.patch_embed(inputs)
+        x, hw_shape = self.patch_embed(inputs), (self.patch_embed.DH,
+                                                 self.patch_embed.DW)
         # stole cls_tokens impl from Phil Wang, thanks
         cls_tokens = self.cls_token.expand(B, -1, -1)
         x = torch.cat((cls_tokens, x), dim=1)
-        x = self._pos_embeding(inputs, x, self.pos_embed)
+        x = self._pos_embeding(x, hw_shape, self.pos_embed)

         if not self.with_cls_token:
             # Remove class token for transformer encoder input
|
||||||
out = x[:, 1:]
|
out = x[:, 1:]
|
||||||
else:
|
else:
|
||||||
out = x
|
out = x
|
||||||
if self.out_shape == 'NCHW':
|
B, _, C = out.shape
|
||||||
B, _, C = out.shape
|
out = out.reshape(B, hw_shape[0], hw_shape[1],
|
||||||
out = out.reshape(B, inputs.shape[2] // self.patch_size,
|
C).permute(0, 3, 1, 2)
|
||||||
inputs.shape[3] // self.patch_size,
|
if self.output_cls_token:
|
||||||
C).permute(0, 3, 1, 2)
|
out = [out, x[:, 0]]
|
||||||
outs.append(out)
|
outs.append(out)
|
||||||
|
|
||||||
return tuple(outs)
|
return tuple(outs)
|
||||||
|
|
|
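The NLC-to-NCHW reshape now uses `hw_shape` directly, so non-square and padded inputs line up with the token count. Shape-wise, a small sketch with assumed values taken from the irregular-input test below:

```python
import torch

# Sketch of the reshape above: token sequence [B, L, C] back to a feature
# map [B, C, H, W], with L == H * W taken from hw_shape.
B, C, hw_shape = 1, 768, (15, 22)
out = torch.randn(B, hw_shape[0] * hw_shape[1], C)
out = out.reshape(B, hw_shape[0], hw_shape[1], C).permute(0, 3, 1, 2)
assert out.shape == (1, 768, 15, 22)
```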
@@ -39,8 +39,8 @@ def test_vit_backbone():
         VisionTransformer(pretrained=123)

     with pytest.raises(AssertionError):
-        # out_shape must be 'NLC' or 'NCHW;'
-        VisionTransformer(out_shape='NCL')
+        # with_cls_token must be True when output_cls_token == True
+        VisionTransformer(with_cls_token=False, output_cls_token=True)

     # Test img_size isinstance tuple
     imgs = torch.randn(1, 3, 224, 224)
@@ -88,6 +88,11 @@ def test_vit_backbone():
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 7, 14)

+    # Test irregular input image
+    imgs = torch.randn(1, 3, 234, 345)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 768, 15, 22)
+
     # Test with_cp=True
     model = VisionTransformer(with_cp=True)
     imgs = torch.randn(1, 3, 224, 224)
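The expected 15x22 grid for the 234x345 input assumes the patch embedding pads the image up to a multiple of the 16-pixel patch size, i.e. ceil division rather than the old `inputs.shape // patch_size` floor:

```python
import math

# Assumed padding behaviour: grid side is ceil(image_side / patch_size).
patch_size = 16
assert math.ceil(234 / patch_size) == 15  # floor division would give 14
assert math.ceil(345 / patch_size) == 22  # floor division would give 21
```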
@@ -100,12 +105,6 @@ def test_vit_backbone():
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 14, 14)

-    # Test out_shape == 'NLC'
-    model = VisionTransformer(out_shape='NLC')
-    imgs = torch.randn(1, 3, 224, 224)
-    feat = model(imgs)
-    assert feat[-1].shape == (1, 196, 768)
-
     # Test final norm
     model = VisionTransformer(final_norm=True)
     imgs = torch.randn(1, 3, 224, 224)
@@ -117,3 +116,10 @@ def test_vit_backbone():
     imgs = torch.randn(1, 3, 224, 224)
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 14, 14)
+
+    # Test output_cls_token
+    model = VisionTransformer(with_cls_token=True, output_cls_token=True)
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat[0][0].shape == (1, 768, 14, 14)
+    assert feat[0][1].shape == (1, 768)