update BEiT3

Ryan 2025-05-12 00:13:52 +08:00
parent 008514934c
commit afe4375e77
2 changed files with 7 additions and 9 deletions


@@ -56,13 +56,13 @@ FEAT_INTER_FILTERS = [
     'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*',
     'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest',
     'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext',
-    'davit', 'rdnet', 'convnext', 'pit'
+    'davit', 'rdnet', 'convnext', 'pit', 'beit3',
 ]
 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
-    'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*',
+    'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*', 'beit3*',
 'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*',
     'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',
 ]
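
These wildcard filters gate which models the test suite treats as non-standard; matching is assumed to be fnmatch-style globbing against model names. A minimal sketch of that assumption (the model names below are illustrative, not taken from the diff):

    from fnmatch import fnmatch

    # Subset of the patterns kept/added in the hunk above.
    patterns = ['beit*', 'beit3*', 'vit_*']

    def is_non_std(model_name: str) -> bool:
        # A model is treated as non-standard if any wildcard pattern matches.
        return any(fnmatch(model_name, pat) for pat in patterns)

    assert is_non_std('beit3_base_patch16_224')  # hypothetical name; caught by 'beit3*' (and 'beit*')
    assert not is_non_std('resnet50')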


@@ -86,7 +86,7 @@ class Attention(nn.Module):
             dim: int,
             num_heads: int,
             drop_rate: float = 0.,
-            norm_layer: LayerType = partial(LayerNorm, eps=1e-5)
+            norm_layer: LayerType = partial(LayerNorm, eps=1e-5),
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -122,7 +122,8 @@ class Attention(nn.Module):
         attn_probs = self.attn_drop(attn_weights)
         attn = torch.bmm(attn_probs, v)
-        attn = attn.transpose(0, 1).reshape(N, B, C).transpose(0, 1)
+        attn = attn.view(B, self.num_heads, N, self.head_dim).transpose(1, 2)
+        attn = attn.reshape(B, N, C)
         attn = self.inner_attn_ln(attn)
         attn = self.out_proj(attn)
         return attn
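
The removed line and the two added lines are equivalent ways of merging the attention heads back into a (B, N, C) tensor; the new form avoids the round trip through a (N, B, C) layout. A quick numerical check of that equivalence, with shapes assumed from the surrounding code (torch.bmm output of shape (B*num_heads, N, head_dim)):

    import torch

    B, num_heads, N, head_dim = 2, 4, 7, 8
    C = num_heads * head_dim
    x = torch.randn(B * num_heads, N, head_dim)

    old = x.transpose(0, 1).reshape(N, B, C).transpose(0, 1)                  # removed path
    new = x.view(B, num_heads, N, head_dim).transpose(1, 2).reshape(B, N, C)  # added path

    assert torch.equal(old, new)  # both merge the heads back into (B, N, C)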
@@ -403,7 +404,7 @@ def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
         'paper_ids': 'arXiv:2208.10442',
         'paper_name': 'Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks',
         'origin_url': 'https://github.com/microsoft/unilm/tree/master/beit3',
-        **kwargs
+        **kwargs,
     }
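
_cfg here follows the common timm pattern of shared defaults merged with per-model overrides via **kwargs, so the added trailing comma only eases future additions. A sketch of the pattern, where the 'url' key and the crop_pct override are assumptions for illustration:

    from typing import Any, Dict

    def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
        return {
            'url': url,  # assumed default key, as in other timm model files
            'paper_ids': 'arXiv:2208.10442',
            'origin_url': 'https://github.com/microsoft/unilm/tree/master/beit3',
            **kwargs,  # per-model overrides win over the shared defaults above
        }

    cfg = _cfg(url='https://example.com/beit3.pth', crop_pct=0.9)
    assert cfg['crop_pct'] == 0.9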
@@ -427,10 +428,7 @@ default_cfgs = generate_default_cfgs({
 })
-def checkpoint_filter_fn(
-        state_dict: Dict[str, torch.Tensor],
-        model: BEiT3,
-) -> Dict[str, torch.Tensor]:
+def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: BEiT3) -> Dict[str, torch.Tensor]:
     if 'model' in state_dict:
         state_dict = state_dict['model']
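
Filter functions like this run on the raw checkpoint before the weights are loaded into the model; the only step shown in the hunk is unwrapping the nested 'model' key used by the original BEiT3 release. A small usage sketch (the checkpoint layout beyond that nesting is illustrative):

    import torch

    def checkpoint_filter_fn(state_dict, model=None):
        # Original-repo checkpoints nest all weights under a top-level 'model' key.
        if 'model' in state_dict:
            state_dict = state_dict['model']
        return state_dict

    ckpt = {'model': {'head.weight': torch.zeros(10, 8)}}  # illustrative layout
    assert 'head.weight' in checkpoint_filter_fn(ckpt)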