Initial run through remapping beit3 -> vision_transformer.py
parent: 9790fea406
commit: 55e52c45ef
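Note (not part of the diff): once this remap is in place, the BEiT3 variants registered further down behave like any other VisionTransformer entrypoint. A minimal usage sketch, assuming a timm build that includes this commit; the model name comes from the register_model functions added below:

import torch
import timm

# 'beit3_base_patch16_224' is registered by this commit; pretrained=False avoids
# depending on the checkpoint URLs defined in default_cfgs.
model = timm.create_model('beit3_base_patch16_224', pretrained=False).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000]) with the default classification head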
@@ -1,5 +1,5 @@
from .beit import *
from .beit3 import *
#from .beit3 import *
from .byoanet import *
from .byobnet import *
from .cait import *
@@ -64,6 +64,7 @@ class Attention(nn.Module):
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            scale_attn_norm: bool = False,
            proj_bias: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
@@ -79,6 +80,7 @@ class Attention(nn.Module):
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.norm = norm_layer(dim) if scale_attn_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
@@ -102,6 +104,7 @@ class Attention(nn.Module):
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.norm(x)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
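Note (illustration, not part of the diff): the new scale_attn_norm flag inserts an extra LayerNorm between the attention output and the output projection, mirroring BEiT3's inner attention LayerNorm, and stays a no-op (nn.Identity) when disabled. A minimal standalone sketch of that placement; names such as TinyAttention are illustrative only:

import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    # Hypothetical, simplified attention block; only the norm placement matters here.
    def __init__(self, dim: int = 64, num_heads: int = 4, scale_attn_norm: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.norm = nn.LayerNorm(dim) if scale_attn_norm else nn.Identity()  # off by default
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q * self.scale) @ k.transpose(-2, -1)
        x = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, C)
        x = self.norm(x)  # the extra LN enabled by scale_attn_norm sits here
        return self.proj(x)

print(TinyAttention(scale_attn_norm=True)(torch.randn(2, 16, 64)).shape)  # torch.Size([2, 16, 64])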
@@ -130,6 +133,8 @@ class Block(nn.Module):
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            scale_attn_norm: bool = False,
            scale_mlp_norm: bool = False,
            proj_bias: bool = True,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
@@ -146,6 +151,7 @@ class Block(nn.Module):
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            scale_attn_norm=scale_attn_norm,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
@@ -159,6 +165,7 @@ class Block(nn.Module):
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            norm_layer=norm_layer if scale_mlp_norm else None,
            bias=proj_bias,
            drop=proj_drop,
        )
@@ -179,6 +186,8 @@ class ResPostBlock(nn.Module):
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            scale_attn_norm: bool = False,
            scale_mlp_norm: bool = False,
            proj_bias: bool = True,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
@@ -196,6 +205,7 @@ class ResPostBlock(nn.Module):
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            scale_attn_norm=scale_attn_norm,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
@@ -208,6 +218,7 @@ class ResPostBlock(nn.Module):
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            norm_layer=norm_layer if scale_mlp_norm else None,
            bias=proj_bias,
            drop=proj_drop,
        )
@@ -443,6 +454,8 @@ class VisionTransformer(nn.Module):
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            scale_attn_norm: bool = False,
            scale_mlp_norm: bool = False,
            proj_bias: bool = True,
            init_values: Optional[float] = None,
            class_token: bool = True,
@@ -563,6 +576,8 @@ class VisionTransformer(nn.Module):
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                scale_attn_norm=scale_attn_norm,
                scale_mlp_norm=scale_mlp_norm,
                proj_bias=proj_bias,
                init_values=init_values,
                proj_drop=proj_drop_rate,
@@ -1166,6 +1181,127 @@ def _convert_aimv2(
    return out_dict


def _convert_beit3(
        state_dict: Dict[str, torch.Tensor],
        model: VisionTransformer,
) -> Dict[str, torch.Tensor]:
    """Convert BEiT3 weights to standard VisionTransformer format.

    First applies BEiT3's own filtering (from multimodal to vision-only BEiT3 format),
    then converts from BEiT3 format to standard VisionTransformer format.
    """
    import re

    # Step 1: Apply BEiT3's own checkpoint filtering logic
    # (equivalent to beit3.checkpoint_filter_fn)
    if 'model' in state_dict:
        state_dict = state_dict['model']

    # If already processed, skip BEiT3 filtering
    if 'patch_embed.proj.weight' in state_dict:
        intermediate_dict = state_dict
    else:
        # Remove text and mask tokens (vision-only)
        state_dict.pop('beit3.text_embed.weight', None)
        state_dict.pop('beit3.vision_embed.mask_token', None)

        intermediate_dict = {}

        for k, v in state_dict.items():
            # Skip B branch weights (use only A branch)
            if '.B.' in k:
                continue
            elif 'vision_embed.cls_token' in k:
                k = 'cls_token'
            else:
                # Apply BEiT3's key transformations
                k = k.replace('beit3.', '')
                k = k.replace('embed_positions.', 'pos_embed.')
                k = k.replace('vision_embed.', 'patch_embed.')
                k = k.replace('encoder.', '')
                k = k.replace('layers.', 'blocks.')
                k = k.replace('ffn.', 'mlp.')
                k = k.replace('ffn_layernorm.', 'norm.')
                k = k.replace('self_attn.', 'attn.')
                k = k.replace('self_attn_layer_norm.', 'norm1.')
                k = k.replace('final_layer_norm.', 'norm2.')
                k = k.replace('A.', '')  # Remove A branch prefix

            intermediate_dict[k] = v

    # Step 2: Convert from BEiT3 format to VisionTransformer format
    out_dict = {}

    for k, v in intermediate_dict.items():
        # Handle attention projections - convert separate q,k,v to fused qkv
        if re.match(r"blocks\.(\d+)\.attn\.[qkv]_proj\.weight", k):
            block_idx = re.search(r"blocks\.(\d+)", k).group(1)
            proj_type = re.search(r"\.([qkv])_proj", k).group(1)

            # Collect all three projections for this block
            q_key = f"blocks.{block_idx}.attn.q_proj.weight"
            k_key = f"blocks.{block_idx}.attn.k_proj.weight"
            v_key = f"blocks.{block_idx}.attn.v_proj.weight"

            if all(key in intermediate_dict for key in [q_key, k_key, v_key]):
                # Only create qkv weight once when we encounter the first projection
                if proj_type == 'q':
                    qkv_weight = torch.cat([
                        intermediate_dict[q_key],
                        intermediate_dict[k_key],
                        intermediate_dict[v_key]
                    ], dim=0)
                    out_dict[f"blocks.{block_idx}.attn.qkv.weight"] = qkv_weight
                # Skip k and v projections as they're handled with q
                continue
            else:
                # Fallback if not all projections available
                out_dict[k.replace('q_proj', 'qkv').replace('k_proj', 'qkv').replace('v_proj', 'qkv')] = v

        # Handle attention projection biases
        elif re.match(r"blocks\.(\d+)\.attn\.[qkv]_proj\.bias", k):
            block_idx = re.search(r"blocks\.(\d+)", k).group(1)
            proj_type = re.search(r"\.([qkv])_proj", k).group(1)

            q_key = f"blocks.{block_idx}.attn.q_proj.bias"
            k_key = f"blocks.{block_idx}.attn.k_proj.bias"
            v_key = f"blocks.{block_idx}.attn.v_proj.bias"

            if all(key in intermediate_dict for key in [q_key, k_key, v_key]):
                if proj_type == 'q':
                    qkv_bias = torch.cat([
                        intermediate_dict[q_key],
                        intermediate_dict[k_key],
                        intermediate_dict[v_key]
                    ], dim=0)
                    out_dict[f"blocks.{block_idx}.attn.qkv.bias"] = qkv_bias
                continue
            else:
                out_dict[k.replace('q_proj', 'qkv').replace('k_proj', 'qkv').replace('v_proj', 'qkv')] = v

        # Map inner attention LayerNorm to scale norm
        elif 'attn.inner_attn_ln' in k:
            out_dict[k.replace('inner_attn_ln', 'norm')] = v

        # Map out_proj to proj
        elif 'attn.out_proj' in k:
            out_dict[k.replace('out_proj', 'proj')] = v
        elif 'attn.proj' in k:
            out_dict[k] = v

        # Handle positional embedding - skip first 2 positions (BEiT3 starts from index 2)
        elif k == 'pos_embed.weight':
            # BEiT3 pos_embed.weight has shape [num_patches + 3, embed_dim]
            # We want [1, num_patches + 1, embed_dim] for standard ViT (cls token + patches)
            out_dict['pos_embed'] = v[2:].unsqueeze(0)  # Skip first 2 positions, add batch dim

        # Pass through other weights unchanged
        else:
            out_dict[k] = v

    return out_dict


def checkpoint_filter_fn(
        state_dict: Dict[str, torch.Tensor],
        model: VisionTransformer,
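Note (illustration, not part of the diff): the two non-trivial remaps in _convert_beit3 are the q/k/v fusion and the positional-embedding trim. A small sketch with hypothetical tensor shapes showing what they produce:

import torch

embed_dim, num_patches = 768, 196  # illustrative sizes for a base/16 model at 224x224

# Separate [dim, dim] q/k/v projection weights become one fused [3*dim, dim] qkv weight.
q_w, k_w, v_w = (torch.randn(embed_dim, embed_dim) for _ in range(3))
qkv_w = torch.cat([q_w, k_w, v_w], dim=0)
assert qkv_w.shape == (3 * embed_dim, embed_dim)

# BEiT3 stores pos_embed.weight as [num_patches + 3, dim]; dropping the first two rows
# and adding a batch dim yields timm's [1, num_patches + 1, dim] layout (cls + patches).
beit3_pos = torch.randn(num_patches + 3, embed_dim)
vit_pos = beit3_pos[2:].unsqueeze(0)
assert vit_pos.shape == (1, num_patches + 1, embed_dim)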
@@ -1186,6 +1322,9 @@ def checkpoint_filter_fn(
        state_dict = _convert_openai_clip(state_dict, model, prefix='module.visual.')
    elif "mask_token" in state_dict:
        state_dict = _convert_dinov2(state_dict, model)
    elif any('beit3.' in k for k in state_dict.keys()):
        # BEiT3 model - multimodal checkpoint with beit3.* prefix
        state_dict = _convert_beit3(state_dict, model)
    elif "encoder" in state_dict:
        # IJEPA, vit in an 'encoder' submodule
        state_dict = state_dict['encoder']
@@ -2377,6 +2516,24 @@ default_cfgs = {
        input_size=(3, 160, 160), crop_pct=0.95),
    'test_vit4.r160_in1k': _cfg(
        input_size=(3, 160, 160), crop_pct=0.95),

    # BEiT3 models (remapped to VisionTransformer with scale_norm=True)
    'beit3_base_patch16_224.in22k_ft_in1k': _cfg(
        url='https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224_in1k.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
    'beit3_base_patch16_224.in22k_indomain_ft_in1k': _cfg(
        url='https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224_in1k.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
    'beit3_large_patch16_224.in22k_ft_in1k': _cfg(
        url='https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224_in1k.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
    'beit3_large_patch16_224.in22k_indomain_ft_in1k': _cfg(
        url='https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_224_in1k.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
    'beit3_giant_patch14_224.untrained': _cfg(
        url='', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
    'beit3_giant_patch14_336.untrained': _cfg(
        url='', input_size=(3, 336, 336), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
}

_quick_gelu_cfgs = [n for n, c in default_cfgs.items() if c.get('notes', ()) and 'quickgelu' in c['notes'][0]]
@@ -4035,6 +4192,58 @@ def test_vit4(pretrained: bool = False, **kwargs) -> VisionTransformer:
    return model


@register_model
def beit3_base_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ BEiT3 Base model (ViT-Base size) with patch size 16x16.
    Remapped to VisionTransformer with scale_norm=True.
    """
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg'
    )
    model = _create_vision_transformer('beit3_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def beit3_large_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ BEiT3 Large model (ViT-Large size) with patch size 16x16.
    Remapped to VisionTransformer with scale_norm=True.
    """
    model_args = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
        scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg'
    )
    model = _create_vision_transformer('beit3_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def beit3_giant_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ BEiT3 Giant model with patch size 14x14.
    Remapped to VisionTransformer with scale_norm=True.
    """
    model_args = dict(
        patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637,
        scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg'
    )
    model = _create_vision_transformer('beit3_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def beit3_giant_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ BEiT3 Giant model with patch size 14x14 and image size 336x336.
    Remapped to VisionTransformer with scale_norm=True.
    """
    model_args = dict(
        img_size=336, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637,
        scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg'
    )
    model = _create_vision_transformer('beit3_giant_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


register_model_deprecations(__name__, {
    'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
    'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',
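Note (illustration, not part of the diff): as in the other vision_transformer entrypoints, the **dict(model_args, **kwargs) pattern in these functions lets caller kwargs override the registered defaults. A short example with illustrative override values:

import timm

# num_classes=0 drops the classifier head; drop_path_rate is standard stochastic depth.
model = timm.create_model(
    'beit3_large_patch16_224',
    pretrained=False,
    num_classes=0,
    drop_path_rate=0.1,
)
print(model.num_features)  # 1024, matching embed_dim in the large config above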