diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index 557ee502..b356e129 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -1,5 +1,5 @@
 from .beit import *
-from .beit3 import *
+#from .beit3 import *
 from .byoanet import *
 from .byobnet import *
 from .cait import *
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 3c7b9a22..6f5e144d 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -64,6 +64,7 @@ class Attention(nn.Module):
             num_heads: int = 8,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            scale_attn_norm: bool = False,
             proj_bias: bool = True,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
@@ -79,6 +80,7 @@ class Attention(nn.Module):
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.norm = norm_layer(dim) if scale_attn_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj = nn.Linear(dim, dim, bias=proj_bias)
         self.proj_drop = nn.Dropout(proj_drop)
@@ -102,6 +104,7 @@ class Attention(nn.Module):
             x = attn @ v
 
         x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.norm(x)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
@@ -130,6 +133,8 @@ class Block(nn.Module):
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            scale_attn_norm: bool = False,
+            scale_mlp_norm: bool = False,
             proj_bias: bool = True,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
@@ -146,6 +151,7 @@ class Block(nn.Module):
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
+            scale_attn_norm=scale_attn_norm,
             proj_bias=proj_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
@@ -159,6 +165,7 @@ class Block(nn.Module):
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
+            norm_layer=norm_layer if scale_mlp_norm else None,
             bias=proj_bias,
             drop=proj_drop,
         )
@@ -179,6 +186,8 @@ class ResPostBlock(nn.Module):
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            scale_attn_norm: bool = False,
+            scale_mlp_norm: bool = False,
             proj_bias: bool = True,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
@@ -196,6 +205,7 @@ class ResPostBlock(nn.Module):
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
+            scale_attn_norm=scale_attn_norm,
             proj_bias=proj_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
@@ -208,6 +218,7 @@ class ResPostBlock(nn.Module):
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
+            norm_layer=norm_layer if scale_mlp_norm else None,
             bias=proj_bias,
             drop=proj_drop,
         )
@@ -443,6 +454,8 @@ class VisionTransformer(nn.Module):
             mlp_ratio: float = 4.,
             qkv_bias: bool = True,
             qk_norm: bool = False,
+            scale_attn_norm: bool = False,
+            scale_mlp_norm: bool = False,
             proj_bias: bool = True,
             init_values: Optional[float] = None,
             class_token: bool = True,
@@ -563,6 +576,8 @@ class VisionTransformer(nn.Module):
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
                 qk_norm=qk_norm,
+                scale_attn_norm=scale_attn_norm,
+                scale_mlp_norm=scale_mlp_norm,
                 proj_bias=proj_bias,
                 init_values=init_values,
                 proj_drop=proj_drop_rate,
@@ -1166,6 +1181,127 @@ def _convert_aimv2(
     return out_dict
 
 
+def _convert_beit3(
+        state_dict: Dict[str, torch.Tensor],
+        model: VisionTransformer,
+) -> Dict[str, torch.Tensor]:
+    """Convert BEiT3 weights to standard VisionTransformer format.
+
+    First applies BEiT3's own filtering (from multimodal to vision-only BEiT3 format),
+    then converts from BEiT3 format to standard VisionTransformer format.
+    """
+    import re
+
+    # Step 1: Apply BEiT3's own checkpoint filtering logic
+    # (equivalent to beit3.checkpoint_filter_fn)
+    if 'model' in state_dict:
+        state_dict = state_dict['model']
+
+    # If already processed, skip BEiT3 filtering
+    if 'patch_embed.proj.weight' in state_dict:
+        intermediate_dict = state_dict
+    else:
+        # Remove text and mask tokens (vision-only)
+        state_dict.pop('beit3.text_embed.weight', None)
+        state_dict.pop('beit3.vision_embed.mask_token', None)
+
+        intermediate_dict = {}
+
+        for k, v in state_dict.items():
+            # Skip B branch weights (use only A branch)
+            if '.B.' in k:
+                continue
+            elif 'vision_embed.cls_token' in k:
+                k = 'cls_token'
+            else:
+                # Apply BEiT3's key transformations
+                k = k.replace('beit3.', '')
+                k = k.replace('embed_positions.', 'pos_embed.')
+                k = k.replace('vision_embed.', 'patch_embed.')
+                k = k.replace('encoder.', '')
+                k = k.replace('layers.', 'blocks.')
+                k = k.replace('ffn.', 'mlp.')
+                k = k.replace('ffn_layernorm.', 'norm.')
+                k = k.replace('self_attn.', 'attn.')
+                k = k.replace('self_attn_layer_norm.', 'norm1.')
+                k = k.replace('final_layer_norm.', 'norm2.')
+                k = k.replace('A.', '')  # Remove A branch prefix
+
+            intermediate_dict[k] = v
+
+    # Step 2: Convert from BEiT3 format to VisionTransformer format
+    out_dict = {}
+
+    for k, v in intermediate_dict.items():
+        # Handle attention projections - convert separate q, k, v to fused qkv
+        if re.match(r"blocks\.(\d+)\.attn\.[qkv]_proj\.weight", k):
+            block_idx = re.search(r"blocks\.(\d+)", k).group(1)
+            proj_type = re.search(r"\.([qkv])_proj", k).group(1)
+
+            # Collect all three projections for this block
+            q_key = f"blocks.{block_idx}.attn.q_proj.weight"
+            k_key = f"blocks.{block_idx}.attn.k_proj.weight"
+            v_key = f"blocks.{block_idx}.attn.v_proj.weight"
+
+            if all(key in intermediate_dict for key in [q_key, k_key, v_key]):
+                # Create the fused qkv weight only once, when the q projection is seen
+                if proj_type == 'q':
+                    qkv_weight = torch.cat([
+                        intermediate_dict[q_key],
+                        intermediate_dict[k_key],
+                        intermediate_dict[v_key]
+                    ], dim=0)
+                    out_dict[f"blocks.{block_idx}.attn.qkv.weight"] = qkv_weight
+                # Skip k and v projections, they are handled together with q
+                continue
+            else:
+                # Fallback if not all projections are available
+                out_dict[k.replace('q_proj', 'qkv').replace('k_proj', 'qkv').replace('v_proj', 'qkv')] = v
+
+        # Handle attention projection biases the same way
+        elif re.match(r"blocks\.(\d+)\.attn\.[qkv]_proj\.bias", k):
+            block_idx = re.search(r"blocks\.(\d+)", k).group(1)
+            proj_type = re.search(r"\.([qkv])_proj", k).group(1)
+
+            q_key = f"blocks.{block_idx}.attn.q_proj.bias"
+            k_key = f"blocks.{block_idx}.attn.k_proj.bias"
+            v_key = f"blocks.{block_idx}.attn.v_proj.bias"
+
+            if all(key in intermediate_dict for key in [q_key, k_key, v_key]):
+                if proj_type == 'q':
+                    qkv_bias = torch.cat([
+                        intermediate_dict[q_key],
+                        intermediate_dict[k_key],
+                        intermediate_dict[v_key]
+                    ], dim=0)
+                    out_dict[f"blocks.{block_idx}.attn.qkv.bias"] = qkv_bias
+                continue
+            else:
+                out_dict[k.replace('q_proj', 'qkv').replace('k_proj', 'qkv').replace('v_proj', 'qkv')] = v
+
+        # Map inner attention LayerNorm to the scale norm
+        elif 'attn.inner_attn_ln' in k:
+            out_dict[k.replace('inner_attn_ln', 'norm')] = v
+
+        # Map out_proj to proj
+        elif 'attn.out_proj' in k:
+            out_dict[k.replace('out_proj', 'proj')] = v
+        elif 'attn.proj' in k:
+            out_dict[k] = v
+
+        # Handle positional embedding - skip first 2 positions (BEiT3 starts from index 2)
+        elif k == 'pos_embed.weight':
+            # BEiT3 pos_embed.weight has shape [num_patches + 3, embed_dim],
+            # we want [1, num_patches + 1, embed_dim] for standard ViT (cls token + patches)
+            out_dict['pos_embed'] = v[2:].unsqueeze(0)  # Skip first 2 positions, add batch dim
+
+        # Pass through other weights unchanged
+        else:
+            out_dict[k] = v
+
+    return out_dict
+
+
 def checkpoint_filter_fn(
         state_dict: Dict[str, torch.Tensor],
         model: VisionTransformer,
@@ -1186,6 +1322,9 @@ def checkpoint_filter_fn(
         state_dict = _convert_openai_clip(state_dict, model, prefix='module.visual.')
     elif "mask_token" in state_dict:
         state_dict = _convert_dinov2(state_dict, model)
+    elif any('beit3.' in k for k in state_dict.keys()):
+        # BEiT3 model - multimodal checkpoint with beit3.* prefix
+        state_dict = _convert_beit3(state_dict, model)
     elif "encoder" in state_dict:
         # IJEPA, vit in an 'encoder' submodule
         state_dict = state_dict['encoder']
@@ -2377,6 +2516,24 @@ default_cfgs = {
         input_size=(3, 160, 160), crop_pct=0.95),
     'test_vit4.r160_in1k': _cfg(
         input_size=(3, 160, 160), crop_pct=0.95),
+
+    # BEiT3 models (remapped to VisionTransformer with scale_attn_norm/scale_mlp_norm=True)
+    'beit3_base_patch16_224.in22k_ft_in1k': _cfg(
+        url='https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224_in1k.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
+    'beit3_base_patch16_224.in22k_indomain_ft_in1k': _cfg(
+        url='https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224_in1k.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
+    'beit3_large_patch16_224.in22k_ft_in1k': _cfg(
+        url='https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224_in1k.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
+    'beit3_large_patch16_224.in22k_indomain_ft_in1k': _cfg(
+        url='https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_224_in1k.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
+    'beit3_giant_patch14_224.untrained': _cfg(
+        url='', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
+    'beit3_giant_patch14_336.untrained': _cfg(
+        url='', input_size=(3, 336, 336), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
 }
 
 _quick_gelu_cfgs = [n for n, c in default_cfgs.items() if c.get('notes', ()) and 'quickgelu' in c['notes'][0]]
@@ -4035,6 +4192,58 @@ def test_vit4(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def beit3_base_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ BEiT3 Base model (ViT-Base size) with patch size 16x16.
+    Remapped to VisionTransformer with scale_attn_norm=True and scale_mlp_norm=True.
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+        scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg'
+    )
+    model = _create_vision_transformer('beit3_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def beit3_large_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ BEiT3 Large model (ViT-Large size) with patch size 16x16.
+    Remapped to VisionTransformer with scale_attn_norm=True and scale_mlp_norm=True.
+ """ + model_args = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, + scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg' + ) + model = _create_vision_transformer('beit3_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def beit3_giant_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ BEiT3 Giant model with patch size 14x14. + Remapped to VisionTransformer with scale_norm=True. + """ + model_args = dict( + patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637, + scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg' + ) + model = _create_vision_transformer('beit3_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def beit3_giant_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ BEiT3 Giant model with patch size 14x14 and image size 336x336. + Remapped to VisionTransformer with scale_norm=True. + """ + model_args = dict( + img_size=336, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637, + scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg' + ) + model = _create_vision_transformer('beit3_giant_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + register_model_deprecations(__name__, { 'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k', 'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',