mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Add option for output shape of ViT (#530)
* Add arg: final_reshape to control whether the output feature information is converted from NLC to NCHW; * Fix the default value of final_reshape; * Modify arg: final_reshape to arg: out_shape; * Fix some unit test bugs;
This commit is contained in:
parent
f884489120
commit
aa9b609f11
@ -234,6 +234,8 @@ class VisionTransformer(nn.Module):
|
||||
and its variants only. Default: False.
|
||||
final_norm (bool): Whether to add an additional layer to normalize
|
||||
final feature map. Default: False.
|
||||
out_shape (str): Select the output format of feature information.
|
||||
Must be 'NLC' or 'NCHW'. Default: NCHW.
|
||||
interpolate_mode (str): Select the interpolate mode for position
|
||||
embedding vector resize. Default: bicubic.
|
||||
with_cls_token (bool): If concatenating class token into image tokens
|
||||
@ -261,6 +263,7 @@ class VisionTransformer(nn.Module):
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_eval=False,
|
||||
final_norm=False,
|
||||
out_shape='NCHW',
|
||||
with_cls_token=True,
|
||||
interpolate_mode='bicubic',
|
||||
with_cp=False):
|
||||
@ -303,6 +306,11 @@ class VisionTransformer(nn.Module):
|
||||
with_cp=with_cp) for i in range(depth)
|
||||
])
|
||||
|
||||
assert out_shape in ['NLC',
|
||||
'NCHW'], 'output shape must be "NLC" or "NCHW".'
|
||||
|
||||
self.out_shape = out_shape
|
||||
|
||||
self.interpolate_mode = interpolate_mode
|
||||
self.final_norm = final_norm
|
||||
if final_norm:
|
||||
@ -443,10 +451,11 @@ class VisionTransformer(nn.Module):
|
||||
out = x[:, 1:]
|
||||
else:
|
||||
out = x
|
||||
B, _, C = out.shape
|
||||
out = out.reshape(B, inputs.shape[2] // self.patch_size,
|
||||
inputs.shape[3] // self.patch_size,
|
||||
C).permute(0, 3, 1, 2)
|
||||
if self.out_shape == 'NCHW':
|
||||
B, _, C = out.shape
|
||||
out = out.reshape(B, inputs.shape[2] // self.patch_size,
|
||||
inputs.shape[3] // self.patch_size,
|
||||
C).permute(0, 3, 1, 2)
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
||||
|
@ -30,6 +30,10 @@ def test_vit_backbone():
|
||||
model = VisionTransformer()
|
||||
model(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# out_shape must be 'NLC' or 'NCHW'
|
||||
VisionTransformer(out_shape='NCL')
|
||||
|
||||
# Test img_size isinstance int
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
model = VisionTransformer(img_size=224)
|
||||
@ -72,3 +76,9 @@ def test_vit_backbone():
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 14)
|
||||
|
||||
# Test out_shape arg
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
model = VisionTransformer(out_shape='NLC')
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 196, 768)
|
||||
|
Loading…
x
Reference in New Issue
Block a user