Fix class token support in MViT-V2, add small_class variant to ensure it's tested. Fix #1443
parent
b94b7cea65
commit
f66e5f0e35
|
@ -135,6 +135,11 @@ model_cfgs = dict(
|
|||
num_heads=2,
|
||||
expand_attn=False,
|
||||
),
|
||||
|
||||
mvitv2_small_cls=MultiScaleVitCfg(
|
||||
depths=(1, 2, 11, 2),
|
||||
use_cls_token=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
@ -641,7 +646,7 @@ class MultiScaleBlock(nn.Module):
|
|||
if self.shortcut_pool_attn is None:
|
||||
return x
|
||||
if self.has_cls_token:
|
||||
cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :]
|
||||
cls_tok, x = x[:, :1, :], x[:, 1:, :]
|
||||
else:
|
||||
cls_tok = None
|
||||
B, L, C = x.shape
|
||||
|
@ -650,7 +655,7 @@ class MultiScaleBlock(nn.Module):
|
|||
x = self.shortcut_pool_attn(x)
|
||||
x = x.reshape(B, C, -1).transpose(1, 2)
|
||||
if cls_tok is not None:
|
||||
x = torch.cat((cls_tok, x), dim=2)
|
||||
x = torch.cat((cls_tok, x), dim=1)
|
||||
return x
|
||||
|
||||
def forward(self, x, feat_size: List[int]):
|
||||
|
@ -996,3 +1001,8 @@ def mvitv2_large(pretrained=False, **kwargs):
|
|||
# @register_model
|
||||
# def mvitv2_huge_in21k(pretrained=False, **kwargs):
|
||||
# return _create_mvitv2('mvitv2_huge_in21k', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def mvitv2_small_cls(pretrained=False, **kwargs):
|
||||
return _create_mvitv2('mvitv2_small_cls', pretrained=pretrained, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue