From b5a814e4c14b571dc4a90e9221a1dee2ef4472b4 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Mon, 12 May 2025 00:24:15 +0800 Subject: [PATCH] add giant model param --- timm/models/beit3.py | 58 ++++++++++++++++++++++++++++++++------------ 1 file changed, 42 insertions(+), 16 deletions(-) diff --git a/timm/models/beit3.py b/timm/models/beit3.py index 99b5a1ef..1bc52b32 100644 --- a/timm/models/beit3.py +++ b/timm/models/beit3.py @@ -9,12 +9,14 @@ Model from official source: @inproceedings{beit3, title={Image as a foreign language: {BEiT} pretraining for vision and vision-language tasks}, - author={Wenhui Wang and Hangbo Bao and Li Dong and Johan Bjorck and Zhiliang Peng and Qiang Liu and Kriti Aggarwal and Owais Khan Mohammed and Saksham Singhal and Subhojit Som and Furu Wei}, + author={Wenhui Wang and Hangbo Bao and Li Dong and Johan Bjorck and Zhiliang Peng and Qiang Liu and Kriti Aggarwal + and Owais Khan Mohammed and Saksham Singhal and Subhojit Som and Furu Wei}, booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, year={2023} } @InProceedings{Wang_2023_CVPR, - author = {Wang, Wenhui and Bao, Hangbo and Dong, Li and Bjorck, Johan and Peng, Zhiliang and Liu, Qiang and Aggarwal, Kriti and Mohammed, Owais Khan and Singhal, Saksham and Som, Subhojit and Wei, Furu}, + author = {Wang, Wenhui and Bao, Hangbo and Dong, Li and Bjorck, Johan and Peng, Zhiliang and Liu, Qiang and Aggarwal, + Kriti and Mohammed, Owais Khan and Singhal, Saksham and Som, Subhojit and Wei, Furu}, title = {Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks}, booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, month = {June}, @@ -65,6 +67,7 @@ class PositionalEmbedding(nn.Embedding): https://github.com/microsoft/torchscale/blob/main/torchscale/component/embedding.py#L99-L119 """ def forward(self, x: torch.Tensor) -> torch.Tensor: + # being consistent with Fairseq, which starts from 2. return F.embedding( torch.arange(2, self.num_embeddings).long().unsqueeze(0).to(x.device), self.weight, @@ -108,22 +111,23 @@ class Attention(nn.Module): v = self.v_proj(x) q *= self.scaling + ## (B, N, C) >> (B, N, num_heads, head_dim) >> (B, num_heads, N, head_dim) q = q.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) k = k.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) + + ## (B, num_heads, N, head_dim) >> (B * num_heads, N, head_dim) q = q.reshape(B * self.num_heads, N, self.head_dim) k = k.reshape(B * self.num_heads, N, self.head_dim) v = v.reshape(B * self.num_heads, N, self.head_dim) - attn_weights = torch.bmm(q, k.transpose(1, 2)) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as( - attn_weights - ) + attn_weights = torch.bmm(q, k.transpose(1, 2)) # (B * num_heads, N, N) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights) attn_probs = self.attn_drop(attn_weights) + attn = torch.bmm(attn_probs, v) # (B * num_heads, N, head_dim) - attn = torch.bmm(attn_probs, v) - attn = attn.view(B, self.num_heads, N, self.head_dim).transpose(1, 2) - attn = attn.reshape(B, N, C) + ## (B * num_heads N, head_dim) >> (B, N, num_heads * head_dim) == (B, N, C) + attn = attn.view(B, self.num_heads, N, self.head_dim).transpose(1, 2).reshape(B, N, C) attn = self.inner_attn_ln(attn) attn = self.out_proj(attn) return attn @@ -409,26 +413,28 @@ def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]: default_cfgs = generate_default_cfgs({ - 'beit3_base_patch16_224.in1k': _cfg( + 'beit3_base_patch16_224.in22k_ft_in1k': _cfg( url='https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224_in1k.pth', # hf_hub_id='timm/', ), - 'beit3_base_patch16_224.indomain_in1k': _cfg( + 'beit3_base_patch16_224.in22k_indomain_ft_in1k': _cfg( url='https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224_in1k.pth', # hf_hub_id='timm/', ), - 'beit3_large_patch16_224.in1k': _cfg( + 'beit3_large_patch16_224.in22k_ft_in1k': _cfg( url='https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224_in1k.pth', # hf_hub_id='timm/', ), - 'beit3_large_patch16_224.indomain_in1k': _cfg( + 'beit3_large_patch16_224.in22k_indomain_ft_in1k': _cfg( url='https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_224_in1k.pth', # hf_hub_id='timm/', ), + 'beit3_giant_patch14_224.untrained': _cfg(url=''), + 'beit3_giant_patch14_336.untrained': _cfg(url='', input_size=(3, 336, 336)), }) -def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: BEiT3) -> Dict[str, torch.Tensor]: +def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]: if 'model' in state_dict: state_dict = state_dict['model'] @@ -459,11 +465,11 @@ def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: BEiT3) -> D k = k.replace('A.', '') out_dict[k] = v - + return out_dict -def _create_beit3(variant: str, pretrained: bool, **kwargs: Any) -> BEiT3: +def _create_beit3(variant: str, pretrained: bool = False, **kwargs: Any) -> BEiT3: out_indices = kwargs.pop('out_indices', 3) model = build_model_with_cfg( BEiT3, variant, pretrained, @@ -488,3 +494,23 @@ def beit3_large_patch16_224(pretrained: bool = False, **kwargs: Any) -> BEiT3: patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4) model = _create_beit3('beit3_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model + + +@register_model +def beit3_giant_patch14_224(pretrained: bool = False, **kwargs: Any) -> BEiT3: + ## FFN inner hidden size = embed_dim * mlp_ratio + ## 6144 = int(1408 * 4.3637) + model_args = dict( + patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637) + model = _create_beit3('beit3_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def beit3_giant_patch14_336(pretrained: bool = False, **kwargs: Any) -> BEiT3: + ## FFN inner hidden size = embed_dim * mlp_ratio + ## 6144 = int(1408 * 4.3637) + model_args = dict( + img_size=336, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637) + model = _create_beit3('beit3_giant_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model