add giant model param

Ryan 2025-05-12 00:24:15 +08:00
parent afe4375e77
commit b5a814e4c1


@@ -9,12 +9,14 @@ Model from official source:
 @inproceedings{beit3,
 title={Image as a foreign language: {BEiT} pretraining for vision and vision-language tasks},
-author={Wenhui Wang and Hangbo Bao and Li Dong and Johan Bjorck and Zhiliang Peng and Qiang Liu and Kriti Aggarwal and Owais Khan Mohammed and Saksham Singhal and Subhojit Som and Furu Wei},
+author={Wenhui Wang and Hangbo Bao and Li Dong and Johan Bjorck and Zhiliang Peng and Qiang Liu and Kriti Aggarwal
+and Owais Khan Mohammed and Saksham Singhal and Subhojit Som and Furu Wei},
 booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
 year={2023}
 }
 @InProceedings{Wang_2023_CVPR,
-author = {Wang, Wenhui and Bao, Hangbo and Dong, Li and Bjorck, Johan and Peng, Zhiliang and Liu, Qiang and Aggarwal, Kriti and Mohammed, Owais Khan and Singhal, Saksham and Som, Subhojit and Wei, Furu},
+author = {Wang, Wenhui and Bao, Hangbo and Dong, Li and Bjorck, Johan and Peng, Zhiliang and Liu, Qiang and Aggarwal,
+Kriti and Mohammed, Owais Khan and Singhal, Saksham and Som, Subhojit and Wei, Furu},
 title = {Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks},
 booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
 month = {June},
@@ -65,6 +67,7 @@ class PositionalEmbedding(nn.Embedding):
     https://github.com/microsoft/torchscale/blob/main/torchscale/component/embedding.py#L99-L119
     """
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Positions start from 2 to stay consistent with Fairseq.
         return F.embedding(
             torch.arange(2, self.num_embeddings).long().unsqueeze(0).to(x.device),
             self.weight,
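An aside on the comment added in this hunk: the lookup indices deliberately begin at 2 rather than 0, following the Fairseq/torchscale convention of reserving the leading embedding rows (see the linked embedding.py). A tiny standalone sketch of that offset, using illustrative sizes rather than values from this file:

import torch

# Illustrative sizes only: 196 patch tokens + 1 CLS token, plus the 2 reserved leading rows.
num_tokens = 197
num_embeddings = num_tokens + 2
positions = torch.arange(2, num_embeddings).long().unsqueeze(0)
print(positions.shape)  # torch.Size([1, 197]) -> one position id per token, starting at index 2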
@@ -108,22 +111,23 @@ class Attention(nn.Module):
         v = self.v_proj(x)
         q *= self.scaling
+        ## (B, N, C) >> (B, N, num_heads, head_dim) >> (B, num_heads, N, head_dim)
         q = q.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
         k = k.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
         v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
+        ## (B, num_heads, N, head_dim) >> (B * num_heads, N, head_dim)
         q = q.reshape(B * self.num_heads, N, self.head_dim)
         k = k.reshape(B * self.num_heads, N, self.head_dim)
         v = v.reshape(B * self.num_heads, N, self.head_dim)
-        attn_weights = torch.bmm(q, k.transpose(1, 2))
-        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
-            attn_weights
-        )
+        attn_weights = torch.bmm(q, k.transpose(1, 2))  # (B * num_heads, N, N)
+        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
         attn_probs = self.attn_drop(attn_weights)
-        attn = torch.bmm(attn_probs, v)
-        attn = attn.view(B, self.num_heads, N, self.head_dim).transpose(1, 2)
-        attn = attn.reshape(B, N, C)
+        attn = torch.bmm(attn_probs, v)  # (B * num_heads, N, head_dim)
+        ## (B * num_heads, N, head_dim) >> (B, N, num_heads * head_dim) == (B, N, C)
+        attn = attn.view(B, self.num_heads, N, self.head_dim).transpose(1, 2).reshape(B, N, C)
         attn = self.inner_attn_ln(attn)
         attn = self.out_proj(attn)
         return attn
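For anyone checking the shape comments added in this hunk, here is a small self-contained sketch of the same reshape/bmm bookkeeping with toy sizes (B=2, N=5, C=8, 2 heads). It omits the scaling, dropout, and layer norm and is not the module itself:

import torch
import torch.nn.functional as F

B, N, C, num_heads = 2, 5, 8, 2
head_dim = C // num_heads
q, k, v = (torch.randn(B, N, C) for _ in range(3))

# (B, N, C) >> (B, num_heads, N, head_dim) >> (B * num_heads, N, head_dim)
q = q.view(B, N, num_heads, head_dim).transpose(1, 2).reshape(B * num_heads, N, head_dim)
k = k.view(B, N, num_heads, head_dim).transpose(1, 2).reshape(B * num_heads, N, head_dim)
v = v.view(B, N, num_heads, head_dim).transpose(1, 2).reshape(B * num_heads, N, head_dim)

attn_weights = F.softmax(torch.bmm(q, k.transpose(1, 2)), dim=-1)  # (B * num_heads, N, N)
attn = torch.bmm(attn_weights, v)                                  # (B * num_heads, N, head_dim)
out = attn.view(B, num_heads, N, head_dim).transpose(1, 2).reshape(B, N, C)
print(out.shape)  # torch.Size([2, 5, 8])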
@@ -409,26 +413,28 @@ def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
 default_cfgs = generate_default_cfgs({
-    'beit3_base_patch16_224.in1k': _cfg(
+    'beit3_base_patch16_224.in22k_ft_in1k': _cfg(
         url='https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224_in1k.pth',
         # hf_hub_id='timm/',
     ),
-    'beit3_base_patch16_224.indomain_in1k': _cfg(
+    'beit3_base_patch16_224.in22k_indomain_ft_in1k': _cfg(
         url='https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224_in1k.pth',
         # hf_hub_id='timm/',
     ),
-    'beit3_large_patch16_224.in1k': _cfg(
+    'beit3_large_patch16_224.in22k_ft_in1k': _cfg(
         url='https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224_in1k.pth',
         # hf_hub_id='timm/',
     ),
-    'beit3_large_patch16_224.indomain_in1k': _cfg(
+    'beit3_large_patch16_224.in22k_indomain_ft_in1k': _cfg(
         url='https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_224_in1k.pth',
         # hf_hub_id='timm/',
     ),
+    'beit3_giant_patch14_224.untrained': _cfg(url=''),
+    'beit3_giant_patch14_336.untrained': _cfg(url='', input_size=(3, 336, 336)),
 })
-def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: BEiT3) -> Dict[str, torch.Tensor]:
+def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
     if 'model' in state_dict:
         state_dict = state_dict['model']
@@ -459,11 +465,11 @@ def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: BEiT3) -> Dict[str, torch.Tensor]:
         k = k.replace('A.', '')
         out_dict[k] = v
     return out_dict
-def _create_beit3(variant: str, pretrained: bool, **kwargs: Any) -> BEiT3:
+def _create_beit3(variant: str, pretrained: bool = False, **kwargs: Any) -> BEiT3:
     out_indices = kwargs.pop('out_indices', 3)
     model = build_model_with_cfg(
         BEiT3, variant, pretrained,
@@ -488,3 +494,23 @@ def beit3_large_patch16_224(pretrained: bool = False, **kwargs: Any) -> BEiT3:
         patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4)
     model = _create_beit3('beit3_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
+
+
+@register_model
+def beit3_giant_patch14_224(pretrained: bool = False, **kwargs: Any) -> BEiT3:
+    ## FFN inner hidden size = embed_dim * mlp_ratio
+    ## 6144 = int(1408 * 4.3637)
+    model_args = dict(
+        patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637)
+    model = _create_beit3('beit3_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def beit3_giant_patch14_336(pretrained: bool = False, **kwargs: Any) -> BEiT3:
+    ## FFN inner hidden size = embed_dim * mlp_ratio
+    ## 6144 = int(1408 * 4.3637)
+    model_args = dict(
+        img_size=336, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637)
+    model = _create_beit3('beit3_giant_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
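The two new variants only get '.untrained' default configs (url=''), so no pretrained weights are wired up by this commit; the unusual mlp_ratio simply reproduces the giant FFN width, since int(1408 * 4.3637) = 6144. A minimal usage sketch, assuming this module is installed as part of timm so the @register_model entries above are discoverable:

import torch
import timm

# Build the new giant variant without pretrained weights (none are published here).
model = timm.create_model('beit3_giant_patch14_224', pretrained=False)
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected (1, num_classes) from the default classifier head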