mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Add option for output shape of ViT (#530)
* Add arg: final_reshape to control if converting output feature information from NLC to NCHW; * Fix the default value of final_reshape; * Modify arg: final_reshape to arg: out_shape; * Fix some unit test bug;
This commit is contained in:
parent
f884489120
commit
aa9b609f11
@ -234,6 +234,8 @@ class VisionTransformer(nn.Module):
|
|||||||
and its variants only. Default: False.
|
and its variants only. Default: False.
|
||||||
final_norm (bool): Whether to add a additional layer to normalize
|
final_norm (bool): Whether to add a additional layer to normalize
|
||||||
final feature map. Default: False.
|
final feature map. Default: False.
|
||||||
|
out_reshape (str): Select the output format of feature information.
|
||||||
|
Default: NCHW.
|
||||||
interpolate_mode (str): Select the interpolate mode for position
|
interpolate_mode (str): Select the interpolate mode for position
|
||||||
embeding vector resize. Default: bicubic.
|
embeding vector resize. Default: bicubic.
|
||||||
with_cls_token (bool): If concatenating class token into image tokens
|
with_cls_token (bool): If concatenating class token into image tokens
|
||||||
@ -261,6 +263,7 @@ class VisionTransformer(nn.Module):
|
|||||||
act_cfg=dict(type='GELU'),
|
act_cfg=dict(type='GELU'),
|
||||||
norm_eval=False,
|
norm_eval=False,
|
||||||
final_norm=False,
|
final_norm=False,
|
||||||
|
out_shape='NCHW',
|
||||||
with_cls_token=True,
|
with_cls_token=True,
|
||||||
interpolate_mode='bicubic',
|
interpolate_mode='bicubic',
|
||||||
with_cp=False):
|
with_cp=False):
|
||||||
@ -303,6 +306,11 @@ class VisionTransformer(nn.Module):
|
|||||||
with_cp=with_cp) for i in range(depth)
|
with_cp=with_cp) for i in range(depth)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
assert out_shape in ['NLC',
|
||||||
|
'NCHW'], 'output shape must be "NLC" or "NCHW".'
|
||||||
|
|
||||||
|
self.out_shape = out_shape
|
||||||
|
|
||||||
self.interpolate_mode = interpolate_mode
|
self.interpolate_mode = interpolate_mode
|
||||||
self.final_norm = final_norm
|
self.final_norm = final_norm
|
||||||
if final_norm:
|
if final_norm:
|
||||||
@ -443,6 +451,7 @@ class VisionTransformer(nn.Module):
|
|||||||
out = x[:, 1:]
|
out = x[:, 1:]
|
||||||
else:
|
else:
|
||||||
out = x
|
out = x
|
||||||
|
if self.out_shape == 'NCHW':
|
||||||
B, _, C = out.shape
|
B, _, C = out.shape
|
||||||
out = out.reshape(B, inputs.shape[2] // self.patch_size,
|
out = out.reshape(B, inputs.shape[2] // self.patch_size,
|
||||||
inputs.shape[3] // self.patch_size,
|
inputs.shape[3] // self.patch_size,
|
||||||
|
@ -30,6 +30,10 @@ def test_vit_backbone():
|
|||||||
model = VisionTransformer()
|
model = VisionTransformer()
|
||||||
model(x)
|
model(x)
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
# out_shape must be 'NLC' or 'NCHW;'
|
||||||
|
VisionTransformer(out_shape='NCL')
|
||||||
|
|
||||||
# Test img_size isinstance int
|
# Test img_size isinstance int
|
||||||
imgs = torch.randn(1, 3, 224, 224)
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
model = VisionTransformer(img_size=224)
|
model = VisionTransformer(img_size=224)
|
||||||
@ -72,3 +76,9 @@ def test_vit_backbone():
|
|||||||
imgs = torch.randn(1, 3, 224, 224)
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
feat = model(imgs)
|
feat = model(imgs)
|
||||||
assert feat[-1].shape == (1, 768, 14, 14)
|
assert feat[-1].shape == (1, 768, 14, 14)
|
||||||
|
|
||||||
|
# Test final reshape arg
|
||||||
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
|
model = VisionTransformer(out_shape='NLC')
|
||||||
|
feat = model(imgs)
|
||||||
|
assert feat[-1].shape == (1, 196, 768)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user