Add option for output shape of ViT (#530)

* Add arg: final_reshape to control if converting output feature information from NLC to NCHW;

* Fix the default value of final_reshape;

* Modify arg: final_reshape to arg: out_shape;

* Fix some unit test bug;
This commit is contained in:
sennnnn 2021-05-06 13:49:28 +08:00 committed by GitHub
parent f884489120
commit aa9b609f11
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 4 deletions

View File

@ -234,6 +234,8 @@ class VisionTransformer(nn.Module):
and its variants only. Default: False.
final_norm (bool): Whether to add an additional layer to normalize
final feature map. Default: False.
out_shape (str): Select the output format of feature information.
Default: NCHW.
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Default: bicubic.
with_cls_token (bool): If concatenating class token into image tokens
@ -261,6 +263,7 @@ class VisionTransformer(nn.Module):
act_cfg=dict(type='GELU'),
norm_eval=False,
final_norm=False,
out_shape='NCHW',
with_cls_token=True,
interpolate_mode='bicubic',
with_cp=False):
@ -303,6 +306,11 @@ class VisionTransformer(nn.Module):
with_cp=with_cp) for i in range(depth)
])
assert out_shape in ['NLC',
'NCHW'], 'output shape must be "NLC" or "NCHW".'
self.out_shape = out_shape
self.interpolate_mode = interpolate_mode
self.final_norm = final_norm
if final_norm:
@ -443,10 +451,11 @@ class VisionTransformer(nn.Module):
out = x[:, 1:]
else:
out = x
B, _, C = out.shape
out = out.reshape(B, inputs.shape[2] // self.patch_size,
inputs.shape[3] // self.patch_size,
C).permute(0, 3, 1, 2)
if self.out_shape == 'NCHW':
B, _, C = out.shape
out = out.reshape(B, inputs.shape[2] // self.patch_size,
inputs.shape[3] // self.patch_size,
C).permute(0, 3, 1, 2)
outs.append(out)
return tuple(outs)

View File

@ -30,6 +30,10 @@ def test_vit_backbone():
model = VisionTransformer()
model(x)
with pytest.raises(AssertionError):
# out_shape must be 'NLC' or 'NCHW'
VisionTransformer(out_shape='NCL')
# Test img_size isinstance int
imgs = torch.randn(1, 3, 224, 224)
model = VisionTransformer(img_size=224)
@ -72,3 +76,9 @@ def test_vit_backbone():
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 14)
# Test final reshape arg
imgs = torch.randn(1, 3, 224, 224)
model = VisionTransformer(out_shape='NLC')
feat = model(imgs)
assert feat[-1].shape == (1, 196, 768)