Compact _covert_beit3 fn

This commit is contained in:
Ross Wightman 2025-05-29 10:52:39 -07:00
parent 38c5f3bc63
commit 2ca94a6ce4

View File

@ -1181,87 +1181,64 @@ def _convert_aimv2(
return out_dict
def _convert_beit3(
state_dict: Dict[str, torch.Tensor],
model: VisionTransformer,
) -> Dict[str, torch.Tensor]:
"""Convert BEiT3 weights to standard VisionTransformer format."""
def _convert_beit3(state_dict: dict, model):
"""
Turn a BEiT-3 checkpoint into a standard VisionTransformer state-dict.
"""
import re
state_dict = state_dict.get("model", state_dict) # unwrap if needed
if 'model' in state_dict:
state_dict = state_dict['model']
# Prune unused
for k in ("beit3.text_embed.weight", "beit3.vision_embed.mask_token"):
state_dict.pop(k, None)
# Remove text and mask tokens (vision-only)
state_dict.pop('beit3.text_embed.weight', None)
state_dict.pop('beit3.vision_embed.mask_token', None)
# Key renaming rules
rules = [
(r"beit3\.", ""),
(r"vision_embed\.cls_token", "cls_token"),
(r"vision_embed\.", "patch_embed."),
(r"embed_positions\.", "pos_embed."),
(r"encoder\.", ""),
(r"layers\.", "blocks."),
(r"ffn_layernorm\.", "norm."), (r"ffn\.", "mlp."),
(r"self_attn_layer_norm\.", "norm1."), (r"self_attn\.", "attn."),
(r"final_layer_norm\.", "norm2."),
(r"inner_attn_ln", "norm"),
(r"out_proj", "proj"),
(r"\.A\.", "."),
]
# First pass: Apply all key transformations except qkv fusion
intermediate_dict = {}
# First pass, rename keys
tmp = {}
for k, v in state_dict.items():
# Skip B branch weights (use only A branch)
if '.B.' in k:
if ".B." in k:
continue # use branch-A only
for old, new in rules:
k = re.sub(old, new, k)
if k == "pos_embed.weight":
# strip first two positions, [1, N+1, D]
tmp["pos_embed"] = v[2:].unsqueeze(0)
else:
tmp[k] = v
# Second pass, fuse q, k, v
out, buf = {}, {}
pat = re.compile(r"blocks\.(\d+)\.attn\.(q|k|v)_proj\.(weight|bias)$")
for k, v in tmp.items():
m = pat.fullmatch(k)
if not m: # anything not q/k/v -> copy through
out[k] = v
continue
# Apply all BEiT3 key transformations in one go
if 'vision_embed.cls_token' in k:
k = 'cls_token'
else:
k = k.replace('beit3.', '')
k = k.replace('embed_positions.', 'pos_embed.')
k = k.replace('vision_embed.', 'patch_embed.')
k = k.replace('encoder.', '')
k = k.replace('layers.', 'blocks.')
k = k.replace('ffn.', 'mlp.')
k = k.replace('ffn_layernorm.', 'norm.')
k = k.replace('self_attn.', 'attn.')
k = k.replace('self_attn_layer_norm.', 'norm1.')
k = k.replace('final_layer_norm.', 'norm2.')
k = k.replace('inner_attn_ln', 'norm') # Map inner attention LayerNorm to scale norm
k = k.replace('out_proj', 'proj') # Map out_proj to proj
k = k.replace('A.', '') # Remove A branch prefix
blk, which, kind = m.groups() # block idx, 'q'/'k'/'v', 'weight'/'bias'
stash = buf.setdefault((blk, kind), {}) # Gather by block & param type
stash[which] = v
if len(stash) == 3: # Have q, k, v -> concatenate
out[f"blocks.{blk}.attn.qkv.{kind}"] = torch.cat(
[stash['q'], stash['k'], stash['v']], dim=0
)
# Handle positional embedding - skip first 2 positions (BEiT3 starts from index 2)
if k == 'pos_embed.weight':
# BEiT3 pos_embed.weight has shape [num_patches + 3, embed_dim]
# We want [1, num_patches + 1, embed_dim] for standard ViT (cls token + patches)
intermediate_dict['pos_embed'] = v[2:].unsqueeze(0) # Skip first 2 positions, add batch dim
else:
intermediate_dict[k] = v
# Second pass: Handle qkv fusion
out_dict = {}
processed_qkv = set()
for k, v in intermediate_dict.items():
# Handle attention projections - convert separate q,k,v to fused qkv
if re.match(r"blocks\.(\d+)\.attn\.[qkv]_proj\.(weight|bias)", k):
block_idx = re.search(r"blocks\.(\d+)", k).group(1)
param_type = re.search(r"\.(weight|bias)$", k).group(1)
# Only process once per block per parameter type
block_param_key = f"{block_idx}_{param_type}"
if block_param_key in processed_qkv:
continue
# Collect all three projections for this block
q_key = f"blocks.{block_idx}.attn.q_proj.{param_type}"
k_key = f"blocks.{block_idx}.attn.k_proj.{param_type}"
v_key = f"blocks.{block_idx}.attn.v_proj.{param_type}"
if all(key in intermediate_dict for key in [q_key, k_key, v_key]):
qkv_tensor = torch.cat([
intermediate_dict[q_key],
intermediate_dict[k_key],
intermediate_dict[v_key]
], dim=0)
out_dict[f"blocks.{block_idx}.attn.qkv.{param_type}"] = qkv_tensor
processed_qkv.add(block_param_key)
continue
else:
assert False
else:
out_dict[k] = v
return out_dict
return out
def checkpoint_filter_fn(