Initial run-through: remapping beit3 -> vision_transformer.py

Ross Wightman 2025-05-29 09:50:17 -07:00
parent 9790fea406
commit 55e52c45ef
2 changed files with 210 additions and 1 deletion
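A minimal usage sketch (not part of the commit itself): once the entrypoints below are registered, the remapped BEiT3 variants go through the standard timm factory, with pretrained weights assumed to be remapped by the new _convert_beit3 filter.

import timm
import torch

model = timm.create_model('beit3_base_patch16_224', pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000]) with the default 1000-class head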


@@ -1,5 +1,5 @@
from .beit import *
from .beit3 import *
#from .beit3 import *
from .byoanet import *
from .byobnet import *
from .cait import *


@@ -64,6 +64,7 @@ class Attention(nn.Module):
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
scale_attn_norm: bool = False,
proj_bias: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
@@ -79,6 +80,7 @@ class Attention(nn.Module):
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.norm = norm_layer(dim) if scale_attn_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
@@ -102,6 +104,7 @@ class Attention(nn.Module):
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.norm(x)
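# nn.Identity unless scale_attn_norm is set, so existing ViT checkpoints are unaffected; BEiT3's inner_attn_ln maps onto this norm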
x = self.proj(x)
x = self.proj_drop(x)
return x
@@ -130,6 +133,8 @@ class Block(nn.Module):
mlp_ratio: float = 4.,
qkv_bias: bool = False,
qk_norm: bool = False,
scale_attn_norm: bool = False,
scale_mlp_norm: bool = False,
proj_bias: bool = True,
proj_drop: float = 0.,
attn_drop: float = 0.,
@@ -146,6 +151,7 @@ class Block(nn.Module):
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
scale_attn_norm=scale_attn_norm,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
@@ -159,6 +165,7 @@ class Block(nn.Module):
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
norm_layer=norm_layer if scale_mlp_norm else None,
bias=proj_bias,
drop=proj_drop,
)
@@ -179,6 +186,8 @@ class ResPostBlock(nn.Module):
mlp_ratio: float = 4.,
qkv_bias: bool = False,
qk_norm: bool = False,
scale_attn_norm: bool = False,
scale_mlp_norm: bool = False,
proj_bias: bool = True,
proj_drop: float = 0.,
attn_drop: float = 0.,
@@ -196,6 +205,7 @@ class ResPostBlock(nn.Module):
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
scale_attn_norm=scale_attn_norm,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
@@ -208,6 +218,7 @@ class ResPostBlock(nn.Module):
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
norm_layer=norm_layer if scale_mlp_norm else None,
bias=proj_bias,
drop=proj_drop,
)
@@ -443,6 +454,8 @@ class VisionTransformer(nn.Module):
mlp_ratio: float = 4.,
qkv_bias: bool = True,
qk_norm: bool = False,
scale_attn_norm: bool = False,
scale_mlp_norm: bool = False,
proj_bias: bool = True,
init_values: Optional[float] = None,
class_token: bool = True,
@@ -563,6 +576,8 @@ class VisionTransformer(nn.Module):
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
scale_attn_norm=scale_attn_norm,
scale_mlp_norm=scale_mlp_norm,
proj_bias=proj_bias,
init_values=init_values,
proj_drop=proj_drop_rate,
@@ -1166,6 +1181,127 @@ def _convert_aimv2(
return out_dict
def _convert_beit3(
state_dict: Dict[str, torch.Tensor],
model: VisionTransformer,
) -> Dict[str, torch.Tensor]:
"""Convert BEiT3 weights to standard VisionTransformer format.
First applies BEiT3's own filtering (from multimodal to vision-only BEiT3 format),
then converts from BEiT3 format to standard VisionTransformer format.
"""
import re
# Step 1: Apply BEiT3's own checkpoint filtering logic
# (equivalent to beit3.checkpoint_filter_fn)
if 'model' in state_dict:
state_dict = state_dict['model']
# If already processed, skip BEiT3 filtering
if 'patch_embed.proj.weight' in state_dict:
intermediate_dict = state_dict
else:
# Remove text and mask tokens (vision-only)
state_dict.pop('beit3.text_embed.weight', None)
state_dict.pop('beit3.vision_embed.mask_token', None)
intermediate_dict = {}
for k, v in state_dict.items():
# Skip B branch weights (use only A branch)
if '.B.' in k:
continue
elif 'vision_embed.cls_token' in k:
k = 'cls_token'
else:
# Apply BEiT3's key transformations
k = k.replace('beit3.', '')
k = k.replace('embed_positions.', 'pos_embed.')
k = k.replace('vision_embed.', 'patch_embed.')
k = k.replace('encoder.', '')
k = k.replace('layers.', 'blocks.')
k = k.replace('ffn.', 'mlp.')
k = k.replace('ffn_layernorm.', 'norm.')
k = k.replace('self_attn.', 'attn.')
k = k.replace('self_attn_layer_norm.', 'norm1.')
k = k.replace('final_layer_norm.', 'norm2.')
k = k.replace('A.', '') # Remove A branch prefix
intermediate_dict[k] = v
# Step 2: Convert from BEiT3 format to VisionTransformer format
out_dict = {}
for k, v in intermediate_dict.items():
# Handle attention projections - convert separate q,k,v to fused qkv
if re.match(r"blocks\.(\d+)\.attn\.[qkv]_proj\.weight", k):
block_idx = re.search(r"blocks\.(\d+)", k).group(1)
proj_type = re.search(r"\.([qkv])_proj", k).group(1)
# Collect all three projections for this block
q_key = f"blocks.{block_idx}.attn.q_proj.weight"
k_key = f"blocks.{block_idx}.attn.k_proj.weight"
v_key = f"blocks.{block_idx}.attn.v_proj.weight"
if all(key in intermediate_dict for key in [q_key, k_key, v_key]):
# Only create qkv weight once when we encounter the first projection
if proj_type == 'q':
qkv_weight = torch.cat([
intermediate_dict[q_key],
intermediate_dict[k_key],
intermediate_dict[v_key]
], dim=0)
out_dict[f"blocks.{block_idx}.attn.qkv.weight"] = qkv_weight
# Skip k and v projections as they're handled with q
continue
else:
# Fallback if not all projections available
out_dict[k.replace('q_proj', 'qkv').replace('k_proj', 'qkv').replace('v_proj', 'qkv')] = v
# Handle attention projection biases
elif re.match(r"blocks\.(\d+)\.attn\.[qkv]_proj\.bias", k):
block_idx = re.search(r"blocks\.(\d+)", k).group(1)
proj_type = re.search(r"\.([qkv])_proj", k).group(1)
q_key = f"blocks.{block_idx}.attn.q_proj.bias"
k_key = f"blocks.{block_idx}.attn.k_proj.bias"
v_key = f"blocks.{block_idx}.attn.v_proj.bias"
if all(key in intermediate_dict for key in [q_key, k_key, v_key]):
if proj_type == 'q':
qkv_bias = torch.cat([
intermediate_dict[q_key],
intermediate_dict[k_key],
intermediate_dict[v_key]
], dim=0)
out_dict[f"blocks.{block_idx}.attn.qkv.bias"] = qkv_bias
continue
else:
out_dict[k.replace('q_proj', 'qkv').replace('k_proj', 'qkv').replace('v_proj', 'qkv')] = v
# Map inner attention LayerNorm to scale norm
elif 'attn.inner_attn_ln' in k:
out_dict[k.replace('inner_attn_ln', 'norm')] = v
# Map out_proj to proj
elif 'attn.out_proj' in k:
out_dict[k.replace('out_proj', 'proj')] = v
elif 'attn.proj' in k:
out_dict[k] = v
# Handle positional embedding - skip first 2 positions (BEiT3 starts from index 2)
elif k == 'pos_embed.weight':
# BEiT3 pos_embed.weight has shape [num_patches + 3, embed_dim]
# We want [1, num_patches + 1, embed_dim] for standard ViT (cls token + patches)
out_dict['pos_embed'] = v[2:].unsqueeze(0) # Skip first 2 positions, add batch dim
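# e.g. BEiT3-base at 224/16: [196 + 3, 768] -> [197, 768] -> [1, 197, 768]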
# Pass through other weights unchanged
else:
out_dict[k] = v
return out_dict
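# Illustrative sketch, not part of this commit: manually pushing a raw BEiT3
# checkpoint (filename taken from the default_cfgs URLs below) through the
# converter. In normal use this happens automatically via checkpoint_filter_fn
# when pretrained=True.
#
#   model = beit3_base_patch16_224(pretrained=False)
#   raw = torch.load('beit3_base_patch16_224_in1k.pth', map_location='cpu')
#   converted = _convert_beit3(raw, model)  # fuse q/k/v -> qkv, drop text/mask/B-branch keys, trim pos_embed
#   model.load_state_dict(converted)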
def checkpoint_filter_fn(
state_dict: Dict[str, torch.Tensor],
model: VisionTransformer,
@@ -1186,6 +1322,9 @@ def checkpoint_filter_fn(
state_dict = _convert_openai_clip(state_dict, model, prefix='module.visual.')
elif "mask_token" in state_dict:
state_dict = _convert_dinov2(state_dict, model)
elif any('beit3.' in k for k in state_dict.keys()):
# BEiT3 model - multimodal checkpoint with beit3.* prefix
state_dict = _convert_beit3(state_dict, model)
elif "encoder" in state_dict:
# IJEPA, vit in an 'encoder' submodule
state_dict = state_dict['encoder']
@@ -2377,6 +2516,24 @@ default_cfgs = {
input_size=(3, 160, 160), crop_pct=0.95),
'test_vit4.r160_in1k': _cfg(
input_size=(3, 160, 160), crop_pct=0.95),
# BEiT3 models (remapped to VisionTransformer with scale_attn_norm=True, scale_mlp_norm=True)
'beit3_base_patch16_224.in22k_ft_in1k': _cfg(
url='https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224_in1k.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
'beit3_base_patch16_224.in22k_indomain_ft_in1k': _cfg(
url='https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224_in1k.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
'beit3_large_patch16_224.in22k_ft_in1k': _cfg(
url='https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224_in1k.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
'beit3_large_patch16_224.in22k_indomain_ft_in1k': _cfg(
url='https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_224_in1k.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
'beit3_giant_patch14_224.untrained': _cfg(
url='', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
'beit3_giant_patch14_336.untrained': _cfg(
url='', input_size=(3, 336, 336), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
}
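# Not part of the commit, just an illustration: the new entries surface through the
# existing registry helpers, e.g.
#   import timm
#   timm.list_models('beit3*')                   # newly registered architectures
#   timm.list_models('beit3*', pretrained=True)  # should exclude the .untrained giant cfgs above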
_quick_gelu_cfgs = [n for n, c in default_cfgs.items() if c.get('notes', ()) and 'quickgelu' in c['notes'][0]]
@@ -4035,6 +4192,58 @@ def test_vit4(pretrained: bool = False, **kwargs) -> VisionTransformer:
return model
@register_model
def beit3_base_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" BEiT3 Base model (ViT-Base size) with patch size 16x16.
Remapped to VisionTransformer with scale_attn_norm=True and scale_mlp_norm=True.
"""
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg'
)
model = _create_vision_transformer('beit3_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def beit3_large_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" BEiT3 Large model (ViT-Large size) with patch size 16x16.
Remapped to VisionTransformer with scale_attn_norm=True and scale_mlp_norm=True.
"""
model_args = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg'
)
model = _create_vision_transformer('beit3_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def beit3_giant_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" BEiT3 Giant model with patch size 14x14.
Remapped to VisionTransformer with scale_attn_norm=True and scale_mlp_norm=True.
"""
model_args = dict(
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637,
scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg'
)
model = _create_vision_transformer('beit3_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def beit3_giant_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" BEiT3 Giant model with patch size 14x14 and image size 336x336.
Remapped to VisionTransformer with scale_attn_norm=True and scale_mlp_norm=True.
"""
model_args = dict(
img_size=336, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637,
scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg'
)
model = _create_vision_transformer('beit3_giant_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
return model
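# Not in the original diff: the fractional mlp_ratio used for the giant variants keeps the
# MLP hidden width integral, int(1408 * 4.3637) == 6144, presumably matching the 6144-wide
# FFN of the published BEiT3-giant configuration.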
register_model_deprecations(__name__, {
'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',